diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 14aa41bef349..96f4e55fb12d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -23,7 +23,7 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, int64_t dimA, int64_t dimB, Value &transposed) { Type transposedType; - if (failed(getTransposedType(input.getType().cast(), + if (failed(getTransposedType(cast(input.getType()), dimA, dimB, transposedType))) return failure(); Value cstDimA = rewriter.create( @@ -554,7 +554,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // conversions which are not supported in Torch-MLIR right now. Torch::ValueTensorType targetTy = - target.getType().cast(); + cast(target.getType()); if (!targetTy.hasDtype()) { return rewriter.notifyMatchFailure(binder.op, "target tensor must have a dtype"); @@ -753,9 +753,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); Type listElemType = - tensors[0] - .getType() - .cast() + cast(tensors[0].getType()) .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); @@ -869,7 +867,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); - auto weightTensorType = weight.getType().cast(); + auto weightTensorType = cast(weight.getType()); if (!weightTensorType || !weightTensorType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected weight type having sizes"); @@ -1188,7 +1186,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); - auto weightTensorType = weight.getType().cast(); + auto weightTensorType = cast(weight.getType()); if (!weightTensorType || !weightTensorType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected weight type having sizes"); @@ -1427,7 +1425,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.customOpNameStringAttr(mode, "mode", "DCR") || binder.tensorResultType(resultType)) return failure(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy || !inputTy.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected input type having sizes"); @@ -1536,9 +1534,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value scale = operands[1]; Value zeropoint = operands[2]; - auto operandTy = operand.getType().cast(); + auto operandTy = cast(operand.getType()); - auto scaleTy = scale.getType().dyn_cast(); + auto scaleTy = dyn_cast(scale.getType()); if (!scaleTy || !scaleTy.hasSizes()) return rewriter.notifyMatchFailure(binder.op, "requires known rank"); if (!resultType.hasDtype()) @@ -1611,7 +1609,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( ratio = rewriter.create(loc, operands[1]); Value trainVal = operands[2]; auto trainTensorType = - trainVal.getType().dyn_cast(); + dyn_cast(trainVal.getType()); if (!trainTensorType) return rewriter.notifyMatchFailure(binder.op, "train tensor must have a type"); @@ -1629,8 +1627,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (auto valueTensorLiteralOp = trainVal.getDefiningOp()) { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); trainingMode = rewriter.create(loc, val); } else { @@ -2072,7 +2069,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( dyn_cast(shape.getType()).getSizes(); SmallVector dimList; Torch::BaseTensorType shapeType = - shape.getType().cast(); + cast(shape.getType()); Type selectResultType = rewriter.getType( ArrayRef({}), shapeType.getOptionalDtype()); Value zero = rewriter.create( diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c0df70b1206d..90c64db33b01 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -104,10 +104,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure( binder.op, "operand grid_sampler bind failure"); - auto inputTensorType = input.getType().cast(); + auto inputTensorType = cast(input.getType()); ArrayRef inputShape = inputTensorType.getSizes(); uint32_t inputRank = inputShape.size(); - auto gridTensorType = grid.getType().cast(); + auto gridTensorType = cast(grid.getType()); ArrayRef gridShape = gridTensorType.getSizes(); uint32_t gridRank = gridShape.size(); @@ -233,7 +233,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( axis = rank + axis; } // need input type and sizes to flatten/unflatten later. - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); if (!inputTy || !inputTy.hasSizes()) return rewriter.notifyMatchFailure( binder.op, "failed to get input type or sizes"); @@ -1065,7 +1065,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); auto transpose = [&](Value m) -> Value { - auto tty = m.getType().cast(); + auto tty = cast(m.getType()); auto shape = tty.getOptionalSizes(); if (shape.has_value()) { llvm::SmallVector newShape(shape.value()); @@ -1134,7 +1134,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) return failure(); - auto inputTensorType = operand.getType().cast(); + auto inputTensorType = cast(operand.getType()); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected input type having sizes"); @@ -1228,7 +1228,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rank = *maybeRank; SmallVector normalized; axis = Torch::toPositiveDim(axis, rank); - auto xType = x.getType().cast(); + auto xType = cast(x.getType()); if (!xType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected input (X) to have sizes"); @@ -1307,7 +1307,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Get pads shape and rank. The pads tensor is expected to be 1-D // tensor. - auto padsTensorType = pads.getType().cast(); + auto padsTensorType = cast(pads.getType()); if (!padsTensorType || !padsTensorType.hasSizes()) { return rewriter.notifyMatchFailure(binder.op, "Expect non empty pad tensor"); @@ -1323,7 +1323,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // As per onnx.Pad documentation, padSize = 2*num_data_axes // (if axes param not passed). Need to be updated when adding // support for `axes` param. - auto dataOpTy = data.getType().cast(); + auto dataOpTy = cast(data.getType()); TensorType dataTensor = dataOpTy.toBuiltinTensor(); if (!dataTensor || !dataTensor.hasRank()) return rewriter.notifyMatchFailure( @@ -1350,7 +1350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } if (!constantValue) { - auto dataTensorType = data.getType().cast(); + auto dataTensorType = cast(data.getType()); if (dataTensorType.getDtype().isa()) constantValue = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 7d34f14b4e70..586b8d4ff053 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -54,7 +54,7 @@ LogicalResult reducedSumImpl(OpBinder binder, SmallVector axesList; Value axesVal; if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = data.getType().dyn_cast(); + auto inputType = dyn_cast(data.getType()); if (!inputType.hasSizes() || !resultType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "unimplemented: expected input and result to have shapes"); @@ -97,7 +97,7 @@ LogicalResult reducedSumImpl(OpBinder binder, } if (axesList.empty()) { Torch::BaseTensorType axesType = - axesVal.getType().cast(); + cast(axesVal.getType()); auto axesTy = dyn_cast(axesVal.getType()); auto axesShape = axesTy.getSizes(); if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) @@ -177,7 +177,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scale = operands[1]; Value zeropoint = operands[2]; - auto scaleTy = scale.getType().dyn_cast(); + auto scaleTy = dyn_cast(scale.getType()); if (!scaleTy || !scaleTy.hasSizes()) return rewriter.notifyMatchFailure(binder.op, "requires known rank"); if (!resultType.hasDtype()) @@ -241,7 +241,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value c = operands.size() == 9 ? operands[8] : nullptr; auto check = [](Value v) { - auto vTy = v.getType().cast(); + auto vTy = cast(v.getType()); return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; }); }; if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) || @@ -250,7 +250,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "not supported for non per-tensor quantization"); auto extract = [&rewriter, &binder](Value v) { - auto vTy = v.getType().cast(); + auto vTy = cast(v.getType()); Type extractTy = rewriter.getType(); if (isa(vTy.getDtype())) extractTy = rewriter.getType(); @@ -268,7 +268,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto make = [&rewriter, &binder](Value v, Value scale, Value zp) -> Value { - auto ty = v.getType().cast(); + auto ty = cast(v.getType()); auto newTy = getQTorchTypeFromTorchIntType(ty); return rewriter.create( binder.getLoc(), newTy, v, scale, zp); @@ -351,7 +351,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value cZp = operands[7]; auto check = [](Value v) { - auto vTy = v.getType().cast(); + auto vTy = cast(v.getType()); for (auto dim : vTy.getSizes()) if (dim != 1) return false; @@ -368,7 +368,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.getType()), ValueRange{}); auto extract = [&rewriter, &binder, &emptyList](Value v) { - auto vTy = v.getType().cast(); + auto vTy = cast(v.getType()); if (!vTy.getSizes().empty()) { vTy = rewriter.getType( ArrayRef({}), vTy.getOptionalDtype()); @@ -393,7 +393,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto make = [&rewriter, &binder](Value v, Value scale, Value zp) -> Value { - auto ty = v.getType().cast(); + auto ty = cast(v.getType()); auto newTy = getQTorchTypeFromTorchIntType(ty); return rewriter.create( binder.getLoc(), newTy, v, scale, zp); @@ -667,7 +667,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return failure(); Value data = inputOperands[0]; - auto inputType = data.getType().dyn_cast(); + auto inputType = dyn_cast(data.getType()); if (!inputType.hasSizes() || !resultType.hasSizes()) return rewriter.notifyMatchFailure( binder.op, @@ -718,7 +718,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (dimList.empty()) { Value axes = inputOperands[1]; Torch::BaseTensorType axesType = - axes.getType().cast(); + cast(axes.getType()); SmallVector selectSizes{1}; Type selectResultType = axesType.getWithSizesAndDtype( selectSizes, axesType.getOptionalDtype()); @@ -760,7 +760,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.tensorOperands(data, axes) || binder.tensorResultType(resultType)) return failure(); - auto inputType = data.getType().dyn_cast(); + auto inputType = dyn_cast(data.getType()); if (!inputType.hasSizes() || !resultType.hasSizes()) return rewriter.notifyMatchFailure( binder.op, @@ -925,8 +925,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // Perform an AtenToDtype op on the squared sum of the operand, stored // now in operand itself. - auto size = operand.getType() - .dyn_cast() + auto size = dyn_cast(operand.getType()) .getOptionalSizes(); auto f32ResultType = rewriter.getType( size, rewriter.getF32Type()); @@ -1005,7 +1004,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value axesVal; if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = data.getType().dyn_cast(); + auto inputType = dyn_cast(data.getType()); if (!inputType.hasSizes() || !resultType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, @@ -1053,7 +1052,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (axesList.empty()) { Torch::BaseTensorType axesType = - axesVal.getType().cast(); + cast(axesVal.getType()); auto axesTy = dyn_cast(axesVal.getType()); auto axesShape = axesTy.getSizes(); if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) @@ -1191,7 +1190,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // Extract the axes values from the axes operand: if (!binder.tensorOperandAtIndex(axes, 1)) { Torch::BaseTensorType axesType = - axes.getType().cast(); + cast(axes.getType()); SmallVector selectSizes{1}; Type selectResultType = axesType.getWithSizesAndDtype( selectSizes, axesType.getOptionalDtype()); @@ -1344,7 +1343,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // Extract the axes values from the axes operand: if (!binder.tensorOperandAtIndex(axes, 1)) { Torch::BaseTensorType axesType = - axes.getType().cast(); + cast(axes.getType()); SmallVector selectSizes{1}; Type selectResultType = axesType.getWithSizesAndDtype( selectSizes, axesType.getOptionalDtype()); @@ -1467,12 +1466,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto loc = binder.getLoc(); auto result0Ty = - binder.op->getResult(0).getType().cast(); - auto resultNTy = binder.op->getResults() - .back() - .getType() - .cast(); - auto selfTy = self.getType().cast(); + cast(binder.op->getResult(0).getType()); + auto resultNTy = cast( + binder.op->getResults().back().getType()); + auto selfTy = cast(self.getType()); int64_t dim = axis; if (dim < 0) @@ -1555,7 +1552,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Failed to get num_outputs attribute"); auto result0Ty = - binder.op->getResult(0).getType().cast(); + cast(binder.op->getResult(0).getType()); auto selfTy = cast(binder.op->getOperand(0).getType()); @@ -1617,7 +1614,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); - auto operandType = operand.getType().cast(); + auto operandType = cast(operand.getType()); TensorType tensorType = operandType.toBuiltinTensor(); if (!tensorType || !tensorType.hasRank()) return failure(); @@ -1705,26 +1702,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } auto context = rewriter.getContext(); - auto operandTorchTy = operand.getType().cast(); + auto operandTorchTy = cast(operand.getType()); auto operandTy = - operandTorchTy.toBuiltinTensor().dyn_cast(); + dyn_cast(operandTorchTy.toBuiltinTensor()); if (!operandTy) return rewriter.notifyMatchFailure( binder.op, "Expected tensor operator argument to be a ranked tensor type"); - auto startsTorchTy = starts.getType().cast(); + auto startsTorchTy = cast(starts.getType()); auto startsTy = - startsTorchTy.toBuiltinTensor().dyn_cast(); + dyn_cast(startsTorchTy.toBuiltinTensor()); int startSize = startsTy.getDimSize(0); - auto endsTorchTy = ends.getType().cast(); - auto endsTy = - endsTorchTy.toBuiltinTensor().dyn_cast(); + auto endsTorchTy = cast(ends.getType()); + auto endsTy = dyn_cast(endsTorchTy.toBuiltinTensor()); int endSize = endsTy.getDimSize(0); auto resultTy = - resultTorchType.toBuiltinTensor().dyn_cast(); + dyn_cast(resultTorchType.toBuiltinTensor()); if (!resultTy) return rewriter.notifyMatchFailure( binder.op, "Expected result type to be a ranked tensor type"); @@ -1768,9 +1764,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "and their dimensions to match"); if (axes) { - auto axesTorchTy = axes.getType().cast(); + auto axesTorchTy = cast(axes.getType()); auto axesTy = - axesTorchTy.toBuiltinTensor().dyn_cast(); + dyn_cast(axesTorchTy.toBuiltinTensor()); int64_t numAxes = axesTy.getDimSize(0); if (!(axesTy && numAxes == endSize)) @@ -1792,7 +1788,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto select = [&](Value v, Value k) -> Value { - auto ty = v.getType().cast(); + auto ty = cast(v.getType()); auto sel = rewriter.create( loc, Torch::ValueTensorType::get(ty.getContext(), ArrayRef{1}, @@ -1872,7 +1868,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } Torch::BaseTensorType shapeType = - shape.getType().cast(); + cast(shape.getType()); SmallVector dimList; SmallVector selectSizes; selectSizes.push_back(1); @@ -2007,7 +2003,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // instead of using the dynamic axes at operand[1]. if (!binder.tensorOperandAtIndex(axes, 1)) { Torch::BaseTensorType axesType = - axes.getType().cast(); + cast(axes.getType()); auto sizes = axesType.getSizes(); for (int i = 0; i < sizes[0]; i++) { Value selectIndex = rewriter.create( @@ -2136,7 +2132,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // int32, int64 Assuming start, limit and delta to be same type (could // they be different?) Torch::BaseTensorType startTensorType = - start.getType().cast(); + cast(start.getType()); bool isFloatDType = startTensorType.getDtype().isF64() || startTensorType.getDtype().isF32(); bool isIntDType = startTensorType.getDtype().isInteger(16) || @@ -2222,7 +2218,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( SmallVector selectSizes; selectSizes.push_back(1); Torch::BaseTensorType shapeType = - repeatDims.getType().cast(); + cast(repeatDims.getType()); Type selectResultType = shapeType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); Value zero = rewriter.create( diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index abd119fc0ac5..2703d48724cf 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -95,7 +95,7 @@ class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern { Value input = adaptor.getA(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); - if (!input.getType().isa()) + if (!isa(input.getType())) input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type()); Value result = rewriter.create(loc, input); rewriter.replaceOp(op, @@ -172,8 +172,8 @@ class ConvertTorchTensorLiteralOp matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); - if (auto elements = op.getValueAttr().dyn_cast()) { - if (auto type = elements.getType().dyn_cast()) { + if (auto elements = dyn_cast(op.getValueAttr())) { + if (auto type = dyn_cast(elements.getType())) { Type elemTy = op.getValueAttr().getElementType(); unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); Type builtinTensorElemTy = IntegerType::get(context, bitWidth); @@ -187,9 +187,9 @@ class ConvertTorchTensorLiteralOp } } if (auto elements = - op.getValueAttr().dyn_cast()) { - if (auto type = elements.getType().dyn_cast()) { - if (auto intType = type.getElementType().dyn_cast()) { + dyn_cast(op.getValueAttr())) { + if (auto type = dyn_cast(elements.getType())) { + if (auto intType = dyn_cast(type.getElementType())) { Type builtinTensorElemTy = IntegerType::get(context, intType.getIntOrFloatBitWidth()); auto shapedType = diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index ad1a17b0ab8c..d034a8293463 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -49,8 +49,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, SmallVector &strides) { Location loc = op.getLoc(); auto input = adaptor.getSelf(); - RankedTensorType inputType = - input.getType().template cast(); + RankedTensorType inputType = cast(input.getType()); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); @@ -73,8 +72,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value builtinTypeStart = adaptor.getStart(); Value builtinTypeEnd = adaptor.getEnd(); - if (torchTypeStart.getType().isa() || - torchTypeEnd.getType().isa()) + if (isa(torchTypeStart.getType()) || + isa(torchTypeEnd.getType())) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep()); @@ -84,7 +83,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, // We cannot use to positive valid dim as for negative strides we need to // clamp to `-1` so that the full tensor bounds are available: Value end = builtinTypeEnd; - if (torchTypeEnd.getType().isa()) { + if (isa(torchTypeEnd.getType())) { end = dimSize; } else { end = castIntToIndex(rewriter, loc, end); @@ -594,7 +593,7 @@ class ConvertAtenFlattenUsingIntsOp int64_t endDim; if (!matchPattern(op.getEndDim(), m_TorchConstantInt(&endDim))) return rewriter.notifyMatchFailure(op, "end_dim must be constant"); - auto type = adaptor.getSelf().getType().cast(); + auto type = cast(adaptor.getSelf().getType()); auto inputRank = type.getRank(); if (inputRank == 1) { // If input rank is equal to 1, then there's no scope for flattening the @@ -604,7 +603,7 @@ class ConvertAtenFlattenUsingIntsOp } auto resultType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); if (startDim < 0) startDim += inputRank; if (endDim < 0) @@ -652,7 +651,7 @@ class ConvertAtenUnflattenIntOp ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); - BaseTensorType outputTensorType = op.getType().cast(); + BaseTensorType outputTensorType = cast(op.getType()); if (!outputTensorType.hasSizes()) return rewriter.notifyMatchFailure( op, "unimplemented: output must have known sizes"); @@ -660,7 +659,7 @@ class ConvertAtenUnflattenIntOp std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); - auto inputTensorType = self.getType().cast(); + auto inputTensorType = cast(self.getType()); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); @@ -901,7 +900,7 @@ class ConvertAtenViewOp : public OpConversionPattern { getInputAndOutputShape(Value inputTorchTensor, SmallVector outputSizeTorchInt) { SmallVector inputShape( - inputTorchTensor.getType().cast().getSizes()); + cast(inputTorchTensor.getType()).getSizes()); SmallVector outputShape(outputSizeTorchInt.size(), kUnknownSize); for (auto [outputDim, outputDimSize] : llvm::enumerate(outputSizeTorchInt)) { @@ -945,11 +944,11 @@ class ConvertAtenViewOp : public OpConversionPattern { return failure(); Location loc = op.getLoc(); Value input = adaptor.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); int64_t resultRank = resultType.getRank(); if (resultRank == 0) { rewriter @@ -1349,7 +1348,7 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern { auto outputDims = b.create(ty, sizes); auto resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); rewriter.replaceOpWithNewOp(op, resultType, self, outputDims); return success(); @@ -1367,13 +1366,13 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value input = adaptor.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputShape = inputType.getShape(); int64_t inputRank = inputType.getRank(); const TypeConverter *typeConverter = getTypeConverter(); auto resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); auto resultShape = resultType.getShape(); int64_t resultRank = resultType.getRank(); @@ -1437,7 +1436,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value input = adaptor.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); int64_t inputRank = inputType.getRank(); if (inputRank == 0) { @@ -1460,7 +1459,7 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { const TypeConverter *typeConverter = getTypeConverter(); auto resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); int64_t resultRank = resultType.getRank(); // If the dim(th) dimension of operand tensor type is not statically unit, @@ -1510,7 +1509,7 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); auto inputRank = - adaptor.getSelf().getType().cast().getRank(); + cast(adaptor.getSelf().getType()).getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -1535,9 +1534,8 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { } } } - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getSelf(), reassociationMap); return success(); @@ -1564,11 +1562,10 @@ class ConvertAtenTransposeIntOp return rewriter.notifyMatchFailure(op, "dim1 must be constant"); auto inVector = adaptor.getSelf(); - auto inType = inVector.getType().cast(); + auto inType = cast(inVector.getType()); auto inputRank = inType.getRank(); - auto outType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto outType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); auto elementType = inType.getElementType(); dim0 = toPositiveDim(dim0, inputRank); @@ -1634,11 +1631,10 @@ class ConvertAtenPermuteOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "all dimensions must be constant"); Value inVector = adaptor.getSelf(); - auto inType = inVector.getType().cast(); + auto inType = cast(inVector.getType()); int64_t inputRank = inType.getRank(); - auto outType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto outType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = inType.getElementType(); // Check if the dimensions are a valid constants. @@ -1747,7 +1743,7 @@ class ConvertAtenCatOp : public OpConversionPattern { getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); RankedTensorType newResultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); int rank = newResultType.getRank(); Value dimValue = op.getDim(); int64_t dim; @@ -1802,7 +1798,7 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern { // which in this case is `inShapeConverted` because this shape will yield // us the dimension size of the output. SmallVector useBroadcastToShape; - int64_t inputRank = self.getType().cast().getRank(); + int64_t inputRank = cast(self.getType()).getRank(); for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e; ++i) { int64_t dim; @@ -1821,7 +1817,7 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern { SmallVector inShapeConverted = getTypeConvertedValues( rewriter, op.getLoc(), getTypeConverter(), inShape); auto newResultType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); Value result; if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, self, inShapeConverted, newResultType, result, @@ -1869,7 +1865,7 @@ class ConvertAtenCopyOp : public OpConversionPattern { Location loc = op.getLoc(); Value self = adaptor.getSelf(); Value src = adaptor.getSrc(); - RankedTensorType selfType = self.getType().cast(); + RankedTensorType selfType = cast(self.getType()); // The non_blocking should be a constant `False`. bool nonBlocking; @@ -1954,7 +1950,7 @@ class ConvertAtenSliceScatterOp } Value src = adaptor.getSrc(); - auto srcType = src.getType().cast(); + auto srcType = cast(src.getType()); int64_t srcRank = srcType.getRank(); SmallVector srcAbstractSizes(srcRank, kUnknownSize); // TODO: audit possibility of sparsity on these tensor @@ -1992,7 +1988,7 @@ class ConvertAtenViewAsComplexOp auto input = adaptor.getSelf(); RankedTensorType resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); auto elementType = resultType.getElementType(); SmallVector resultShape; @@ -2070,9 +2066,9 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { auto input = adaptor.getSelf(); RankedTensorType resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); - RankedTensorType inputType = input.getType().cast(); + RankedTensorType inputType = cast(input.getType()); auto inputElementType = getElementTypeOrSelf(input.getType()); if (!isa(inputElementType)) { return op.emitError("only ComplexType is allowed as input type"); @@ -2157,7 +2153,7 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "dim2 must be constant"); Value inputMatrix = adaptor.getSelf(); - RankedTensorType inputType = inputMatrix.getType().cast(); + RankedTensorType inputType = cast(inputMatrix.getType()); int64_t inputRank = inputType.getRank(); if (inputRank < 2) @@ -2277,7 +2273,7 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { static SmallVector getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor, int64_t offset, int64_t dim1, int64_t dim2) { - auto inputType = tensor.getType().cast(); + auto inputType = cast(tensor.getType()); auto inputRank = inputType.getRank(); // output tensor always has 1 extra dimension @@ -2314,7 +2310,7 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { Location loc = op->getLoc(); Value input = adaptor.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); auto resultRank = inputRank + 1; diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 011978a68a66..9254b1a17ab7 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -80,7 +80,7 @@ class ConvertAtenGatherOp : public OpConversionPattern { if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); int64_t inputRank = - adaptor.getSelf().getType().cast().getRank(); + cast(adaptor.getSelf().getType()).getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -88,7 +88,7 @@ class ConvertAtenGatherOp : public OpConversionPattern { Value indices = adaptor.getIndex(); Value self = adaptor.getSelf(); RankedTensorType newResultTy = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); int64_t rank = newResultTy.getRank(); SmallVector sizes = getTensorSizes(rewriter, loc, indices); @@ -128,9 +128,9 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern { Value weight = adaptor.getWeight(); Value indices = adaptor.getIndices(); RankedTensorType newResultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); - auto weightTy = weight.getType().cast(); + auto weightTy = cast(weight.getType()); if (weightTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "weight must be rank 2"); Value embeddingDim = getDimOp(rewriter, loc, weight, 1); @@ -140,7 +140,7 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern { sizes.push_back(embeddingDim); int64_t resultRank = sizes.size(); - auto indicesTy = indices.getType().cast(); + auto indicesTy = cast(indices.getType()); int64_t indicesRank = indicesTy.getRank(); SmallVector indicesExprs; for (int i = 0; i < indicesRank; i++) @@ -274,15 +274,15 @@ class ConvertAtenEmbeddingBagPaddingIdxOp "include_last_offset is expected to be a constant boolean value."); } - auto weightTy = weight.getType().cast(); + auto weightTy = cast(weight.getType()); if (weightTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "weight must be rank 2"); - auto indicesTy = indices.getType().cast(); + auto indicesTy = cast(indices.getType()); if (indicesTy.getRank() != 1) return rewriter.notifyMatchFailure(op, "indices must be a vector"); - auto offsetsTy = offsets.getType().cast(); + auto offsetsTy = cast(offsets.getType()); if (offsetsTy.getRank() != 1) return rewriter.notifyMatchFailure(op, "offsets much be a vector"); @@ -471,10 +471,9 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { Value input = adaptor.getSelf(); Value indices = adaptor.getIndex(); auto indicesTy = cast(indices.getType()); - RankedTensorType inputType = input.getType().cast(); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType inputType = cast(input.getType()); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = resultType.getElementType(); unsigned inputRank = inputType.getRank(); @@ -604,10 +603,9 @@ class ConvertAtenIndexTensorHackedTwinOp op, "aten.index.Tensor: index tensor must not be None"); } - RankedTensorType inputType = input.getType().cast(); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType inputType = cast(input.getType()); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = resultType.getElementType(); int inputRank = inputType.getRank(); int resultRank = resultType.getRank(); @@ -625,7 +623,7 @@ class ConvertAtenIndexTensorHackedTwinOp int maxRank = -1; for (auto indexTensor : indexTensors) { RankedTensorType indexTensorType = - indexTensor.getType().cast(); + cast(indexTensor.getType()); maxRank = std::max(maxRank, (int)indexTensorType.getRank()); } @@ -639,7 +637,7 @@ class ConvertAtenIndexTensorHackedTwinOp int64_t staticDimSize = -1; for (auto indexTensor : indexTensors) { RankedTensorType indexTensorType = - indexTensor.getType().cast(); + cast(indexTensor.getType()); int64_t indexTensorRank = indexTensorType.getRank(); if ((maxRank - indexTensorRank) > (i - startIndex)) continue; @@ -714,7 +712,7 @@ class ConvertAtenIndexTensorHackedTwinOp for (auto indexTensor : indexTensors) { RankedTensorType indexTensorType = - indexTensor.getType().cast(); + cast(indexTensor.getType()); auto indexTensorShape = makeShapeTorchCompatible(indexTensorType.getShape()); int rank = indexTensorShape.size(); @@ -828,7 +826,7 @@ class ConvertAtenUpsampleNearest2dOp Value input = adaptor.getSelf(); Type resultType = getTypeConverter()->convertType(op.getResult().getType()); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); Type elementType = inputType.getElementType(); @@ -989,7 +987,7 @@ class ConvertAtenUpsampleNearest2dBackwardOp Value gradOutput = adaptor.getGradOutput(); Type resultType = getTypeConverter()->convertType(op.getResult().getType()); - auto gradOutputType = gradOutput.getType().cast(); + auto gradOutputType = cast(gradOutput.getType()); auto gradOutputRank = gradOutputType.getRank(); Type elementType = gradOutputType.getElementType(); diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index b7db0496f516..3f4e6ed66354 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -48,7 +48,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, ValueRange{arg}, - arg.getType().cast().getElementType(), + cast(arg.getType()).getElementType(), [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = rewriter.create(loc, payloadArgs[0], minSIValue); @@ -58,7 +58,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, static Value transposeValue(Location loc, Value value, ArrayRef perms, PatternRewriter &rewriter) { - auto valueTy = value.getType().cast(); + auto valueTy = cast(value.getType()); auto inShape = valueTy.getShape(); llvm::SmallVector outShape; llvm::SmallVector dynDims; @@ -100,8 +100,8 @@ class ConvertAtenMmOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - RankedTensorType lhsType = lhs.getType().cast(); - RankedTensorType rhsType = rhs.getType().cast(); + RankedTensorType lhsType = cast(lhs.getType()); + RankedTensorType rhsType = cast(rhs.getType()); if (lhsType.getRank() != 2 || rhsType.getRank() != 2) { return rewriter.notifyMatchFailure( @@ -109,9 +109,9 @@ class ConvertAtenMmOp : public OpConversionPattern { } ValueTensorType lhsTorchType = - op.getSelf().getType().cast(); + cast(op.getSelf().getType()); ValueTensorType rhsTorchType = - op.getMat2().getType().cast(); + cast(op.getMat2().getType()); Value lhsZeroPoint, rhsZeroPoint; getZeroPoint(op.getSelf(), lhsZeroPoint); @@ -148,7 +148,7 @@ class ConvertAtenMmOp : public OpConversionPattern { "mismatching contracting dimension for torch.aten.mm")); } - auto resultTy = op.getType().cast(); + auto resultTy = cast(op.getType()); auto resultDTy = resultTy.toBuiltinTensor().getElementType(); Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = cast(newResultType).getElementType(); @@ -176,9 +176,9 @@ class ConvertAtenMmOp : public OpConversionPattern { // change uint8 quantization -> int8 quantization int64_t numBits = - lhsType.getElementType().cast().getWidth(); + cast(lhsType.getElementType()).getWidth(); signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); - numBits = rhsType.getElementType().cast().getWidth(); + numBits = cast(rhsType.getElementType()).getWidth(); signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); matmul = @@ -229,9 +229,9 @@ class ConvertAtenFlipOp : public OpConversionPattern { MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); auto selfRank = - adaptor.getSelf().getType().cast().getRank(); + cast(adaptor.getSelf().getType()).getRank(); Type elementType = - adaptor.getSelf().getType().cast().getElementType(); + cast(adaptor.getSelf().getType()).getElementType(); Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); @@ -299,8 +299,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { return failure(); } - auto lhsType = lhs.getType().cast(); - auto rhsType = rhs.getType().cast(); + auto lhsType = cast(lhs.getType()); + auto rhsType = cast(rhs.getType()); auto lhsTorchType = cast(op.getSelf().getType()); auto rhsTorchType = cast(op.getOther().getType()); @@ -348,9 +348,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern { // change uint8 quantization -> int8 quantization int64_t numBits = - lhsType.getElementType().cast().getWidth(); + cast(lhsType.getElementType()).getWidth(); signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); - numBits = rhsType.getElementType().cast().getWidth(); + numBits = cast(rhsType.getElementType()).getWidth(); signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); // for quantized vec-vec, vec-mat, and mat-vec cases, lower to @@ -726,8 +726,8 @@ class ConvertAtenBmmOp : public OpConversionPattern { Location loc = op->getLoc(); Value lhs = adaptor.getSelf(); Value rhs = adaptor.getMat2(); - RankedTensorType lhsType = lhs.getType().cast(); - RankedTensorType rhsType = rhs.getType().cast(); + RankedTensorType lhsType = cast(lhs.getType()); + RankedTensorType rhsType = cast(rhs.getType()); Type newResultType = getTypeConverter()->convertType(op.getType()); Type resultElementType = cast(newResultType).getElementType(); @@ -794,7 +794,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value input = adaptor.getInput(); /* in form of N*C*H*W */ Value weight = adaptor.getWeight(); /* in form of F*C*H*W */ Value bias = adaptor.getBias(); - auto resultTy = op.getType().cast(); + auto resultTy = cast(op.getType()); Value inputZp, weightZp; if (auto make = op.getInput() @@ -826,7 +826,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { } if (inputZp && weightZp && !isa(bias.getType())) { - auto biasDTy = bias.getType().cast().getElementType(); + auto biasDTy = cast(bias.getType()).getElementType(); if (!biasDTy.isInteger(32)) { return rewriter.notifyMatchFailure( op, "quantized result ty should be i32 accumulator"); @@ -838,15 +838,15 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "unimplemented: only constant transposed supported"); - auto inputDTy = input.getType().cast().getElementType(); - auto weightDTy = weight.getType().cast().getElementType(); + auto inputDTy = cast(input.getType()).getElementType(); + auto weightDTy = cast(weight.getType()).getElementType(); auto resultDTy = resultTy.toBuiltinTensor().getElementType(); if (!isa(inputDTy) || !isa(weightDTy) || !isa(resultDTy)) return op.emitError("unimplemented: non-fp not-int type"); - size_t inRank = input.getType().cast().getRank(); + size_t inRank = cast(input.getType()).getRank(); size_t numSpatialDims = inRank - 2; if (numSpatialDims < 1 || numSpatialDims > 3) return rewriter.notifyMatchFailure( @@ -1067,11 +1067,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { rewriter.create(loc, c0, initTensor).getResult(0); } else { - auto biasType = bias.getType().cast(); + auto biasType = cast(bias.getType()); if (biasType.getRank() != 1) return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); - auto resultRank = initTensor.getType().cast().getRank(); + auto resultRank = cast(initTensor.getType()).getRank(); SmallVector indexingMaps = { // bias is used to initialize the channels - dimension 1 of output AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, @@ -1228,9 +1228,9 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Special depthwise case auto inShape = makeShapeTorchCompatible( - input.getType().cast().getShape()); + cast(input.getType()).getShape()); auto weightShape = makeShapeTorchCompatible( - weight.getType().cast().getShape()); + cast(weight.getType()).getShape()); if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { // Collapse weight shape @@ -1264,7 +1264,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { - auto inType = tensor.getType().cast(); + auto inType = cast(tensor.getType()); auto inShape = makeShapeTorchCompatible(inType.getShape()); SmallVector outShape; @@ -1297,7 +1297,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // expand F,C,H,W -> G,F/G,C,H,W auto expandWeight = [&](Value tensor) { - auto inType = tensor.getType().cast(); + auto inType = cast(tensor.getType()); auto inShape = makeShapeTorchCompatible(inType.getShape()); SmallVector outShape{ diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 42d166c5bf90..4157ef285888 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -80,7 +80,7 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter, SmallVectorImpl &dilationInts, SmallVectorImpl &kernelSizeIntValues, SmallVectorImpl &outTensorShape, Value initValue) { - Type elementType = self.getType().cast().getElementType(); + Type elementType = cast(self.getType()).getElementType(); Location loc = op->getLoc(); Value N = getDimOp(rewriter, loc, self, 0); @@ -116,7 +116,7 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter, SmallVector lowPaddingIncludingNC = {0, 0}; SmallVector highPaddingIncludingNC = {0, 0}; - unsigned selfRank = self.getType().cast().getRank(); + unsigned selfRank = cast(self.getType()).getRank(); unsigned paddingIntsSize = paddingInts.size(); if (paddingIntsSize == 2 * (selfRank - 2)) { @@ -153,7 +153,7 @@ static LogicalResult createPoolingOp( SmallVectorImpl &dilationInts, Attribute initValueAttr, SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { Location loc = op->getLoc(); - Type elementType = self.getType().cast().getElementType(); + Type elementType = cast(self.getType()).getElementType(); if (!isa(elementType) && !supportNonFPInput) return op->emitError("unimplemented: non-floating point type"); @@ -214,7 +214,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { bool ceilMode) const { SmallVector outTensorShape; Value self = adaptor.getSelf(); - Type elementType = self.getType().cast().getElementType(); + Type elementType = cast(self.getType()).getElementType(); TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( elementType, APFloat::getInf(cast(elementType).getFloatSemantics(), @@ -307,7 +307,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); - int64_t selfRank = self.getType().cast().getRank(); + int64_t selfRank = cast(self.getType()).getRank(); if (selfRank != Dim + 2) return rewriter.notifyMatchFailure( @@ -326,7 +326,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); - Type elementType = self.getType().cast().getElementType(); + Type elementType = cast(self.getType()).getElementType(); if constexpr (Dim == 2) { SmallVector outTensorShape; @@ -389,7 +389,7 @@ class ConvertAtenMaxPool2dWithIndicesOp Location loc = op->getLoc(); const TypeConverter *typeConverter = getTypeConverter(); Value self = adaptor.getSelf(); - RankedTensorType selfType = self.getType().cast(); + RankedTensorType selfType = cast(self.getType()); Type elementType = selfType.getElementType(); RankedTensorType indicesRankedTensorType = getTypeConverter() @@ -552,7 +552,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { Value self = adaptor.getSelf(); Type inputElementType = - self.getType().cast().getElementType(); + cast(self.getType()).getElementType(); Type resultType = typeConverter->convertType(op.getType()); Type resultElementType = cast(resultType).getElementType(); @@ -592,10 +592,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { if constexpr (std::is_same()) { Value kHtimeskW = rewriter.create( loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); - divisor = - op.getDivisorOverride().getType().template isa() - ? kHtimeskW - : adaptor.getDivisorOverride(); + divisor = isa(op.getDivisorOverride().getType()) + ? kHtimeskW + : adaptor.getDivisorOverride(); } else { divisor = kernelSizeIntValues[0]; } @@ -901,7 +900,7 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { const TypeConverter *typeConverter = this->getTypeConverter(); Value input = adaptor.getSelf(); - RankedTensorType inputType = input.getType().cast(); + RankedTensorType inputType = cast(input.getType()); const Type elementType = inputType.getElementType(); // get rank of input (same as rank of output) diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 6519a272330e..3a0b81f5a10a 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -127,7 +127,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { Value from = adaptor.getFrom(); Value to = adaptor.getTo(); Value generator = adaptor.getGenerator(); - RankedTensorType resultType = self.getType().cast(); + RankedTensorType resultType = cast(self.getType()); Type elemTy = resultType.getElementType(); Type f64Ty = rewriter.getF64Type(); diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index a5238c9b1211..ffb3350a0733 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -66,8 +66,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { cast(typec->convertType(op.getResult(0).getType())); auto idxResultType = cast(typec->convertType(op.getResult(1).getType())); - RankedTensorType inputType = - input.getType().template cast(); + RankedTensorType inputType = cast(input.getType()); Type idxElementType = getElementTypeOrSelf(typec->convertType(idxResultType)); if (!isa(idxElementType)) @@ -472,7 +471,7 @@ class ConvertReductionOp : public ConversionPattern { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; typename T::Adaptor adaptor(operands); opInfo.tensorOperand = adaptor.getSelf(); - auto inputType = opInfo.tensorOperand.getType().cast(); + auto inputType = cast(opInfo.tensorOperand.getType()); if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim))) return rewriter.notifyMatchFailure(op, @@ -480,8 +479,7 @@ class ConvertReductionOp : public ConversionPattern { SmallVector dimList; int64_t dim; - bool isNoneOrEmptyDimList = - op.getDim().getType().template isa(); + bool isNoneOrEmptyDimList = isa(op.getDim().getType()); if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { // Fix negative dimensions, if any, before adding to the list. for (int64_t dim : dimList) { @@ -522,7 +520,7 @@ class ConvertReductionOp : public ConversionPattern { if (isa(op)) { opInfo.tensorOperand = operands[0]; - auto inputType = opInfo.tensorOperand.getType().cast(); + auto inputType = cast(opInfo.tensorOperand.getType()); // `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and // `AtenMinOp` each reduce along all the dimensions of the input tensor. diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 1a549cd5e399..add928392719 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -42,7 +42,7 @@ class ConvertAtenConstantPadNdOp return failure(); Location loc = op->getLoc(); Value self = adaptor.getSelf(); - auto type = self.getType().cast(); + auto type = cast(self.getType()); int64_t rank = type.getRank(); auto primList = op.getPad().getDefiningOp(); @@ -105,7 +105,7 @@ class ConvertAtenConstantPadNdOp convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); Type padType = tensor::PadOp::inferResultType( - self.getType().cast(), staticLow, staticHigh); + cast(self.getType()), staticLow, staticHigh); Value paddedInput = rewriter.create( loc, padType, self, lowPad, highPad, castedValue); rewriter.replaceOpWithNewOp(op, newResultType, paddedInput); @@ -354,7 +354,7 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { // The pin_memory should be either `False` or `none`. bool pinMemory; - if (!op.getPinMemory().getType().template isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return rewriter.notifyMatchFailure( @@ -376,7 +376,7 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { auto resultType = typeConverter->convertType(op.getType()) .template cast(); Type resultElementType; - if (op.getDtype().getType().template isa()) { + if (isa(op.getDtype().getType())) { resultElementType = resultType.getElementType(); } else { int64_t dtypeInt; @@ -423,7 +423,7 @@ class ConvertAtenEmptyMemoryFormatOp // The pin_memory should be either `False` or `none`. bool pinMemory; - if (!op.getPinMemory().getType().template isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) return rewriter.notifyMatchFailure( @@ -480,7 +480,7 @@ class ConvertAtenEmptyMemoryFormatOp resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); auto resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); Type resultElementType; if (op.getDtype().getType().isa()) { resultElementType = getDefaultDtypeForTorchScalar( diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 921bf0a828f4..1f8b2f980a9c 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -38,7 +38,7 @@ class ConvertAtenSizeIntOp : public OpConversionPattern { Location loc = op->getLoc(); Value self = adaptor.getSelf(); Value dim = adaptor.getDim(); - auto type = self.getType().cast(); + auto type = cast(self.getType()); Value inputRank = rewriter.create( loc, rewriter.getI64IntegerAttr(type.getRank())); Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank); @@ -86,8 +86,7 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { Value input = adaptor.getA(); SmallVector inputSizes = getTensorSizes(rewriter, loc, input); int64_t inputRank = inputSizes.size(); - Type inputDtype = - op.getA().getType().template cast().getDtype(); + Type inputDtype = cast(op.getA().getType()).getDtype(); // The `input` tensor must contain exactly one element, i.e., either the // `input` is a zero rank tensor or all the dimensions of the `input` tensor diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 3c5d6cfaee07..430541db74f8 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -34,7 +34,7 @@ using namespace mlir::torch::Torch; // Check if a ranked-tensor has the specified element type. template static bool hasElementType(Value tensor) { - auto tensorType = tensor.getType().cast(); + auto tensorType = cast(tensor.getType()); Type tensorElementType = tensorType.getElementType(); return isa(tensorElementType); } @@ -173,8 +173,7 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op, return nullptr; } - Type elementalType = - op.getSelf().getType().template cast().getDtype(); + Type elementalType = cast(op.getSelf().getType()).getDtype(); if constexpr (std::is_same()) { return createLessThan(b, loc, elementalType, lhs, rhs); } @@ -200,7 +199,7 @@ template static LogicalResult createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op, ArrayRef operands, Value &result) { - auto inputType = operands[0].getType().cast(); + auto inputType = cast(operands[0].getType()); uint64_t inputRank = inputType.getRank(); // Use the indices of the two innermost dimensions. @@ -405,7 +404,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } Type resultElementType = - bitwiseAndScalar.getType().cast().getDtype(); + cast(bitwiseAndScalar.getType()).getDtype(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); @@ -537,7 +536,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (auto relu = dyn_cast(op)) { Value zeroPoint = getZeroPoint(relu.getSelf()); Value arg = payloadArgs[0]; - auto intType = arg.getType().dyn_cast(); + auto intType = dyn_cast(arg.getType()); if (zeroPoint && !intType) { relu.emitError("unimplemented: non-integer quantized Relu."); return nullptr; @@ -739,9 +738,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); - Type resultElementType = add.getType().cast().getDtype(); - Type dtype = converter->convertType(add.getType()) - .cast() + Type resultElementType = cast(add.getType()).getDtype(); + Type dtype = cast(converter->convertType(add.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, @@ -762,10 +760,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto sub = dyn_cast(op)) { AtenSubTensorOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(sub.getType()) - .cast() + Type dtype = cast(converter->convertType(sub.getType())) .getElementType(); - Type resultElementType = sub.getType().cast().getDtype(); + Type resultElementType = cast(sub.getType()).getDtype(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); @@ -785,9 +782,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } } if (auto subScalar = dyn_cast(op)) { - Type dtype = converter->convertType(subScalar.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(subScalar.getType())) + .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype( @@ -805,11 +802,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } if (auto addScalar = dyn_cast(op)) { - Type dtype = converter->convertType(addScalar.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(addScalar.getType())) + .getElementType(); Type resultElementType = - addScalar.getType().cast().getDtype(); + cast(addScalar.getType()).getDtype(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); @@ -832,8 +829,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto mul = dyn_cast(op)) { AtenMulTensorOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(mul.getType()) - .cast() + Type dtype = cast(converter->convertType(mul.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); @@ -846,8 +842,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } } if (auto atan2 = dyn_cast(op)) { - Type dtype = converter->convertType(atan2.getType()) - .cast() + Type dtype = cast(converter->convertType(atan2.getType())) .getElementType(); if (!isa(dtype)) { atan2.emitError("Atan2 requires floating point result type"); @@ -883,8 +878,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(div.getType()) - .cast() + Type dtype = cast(converter->convertType(div.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); @@ -907,7 +901,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( operands); } if (auto pow = dyn_cast(op)) { - Type dtype = pow.getType().cast().getDtype(); + Type dtype = cast(pow.getType()).getDtype(); if (!isa(dtype)) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; @@ -925,14 +919,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp( pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } - Type dtype = pow.getSelf().getType().cast().getDtype(); + Type dtype = cast(pow.getSelf().getType()).getDtype(); Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, payloadArgs[0], expPromoted); } if (auto pow = dyn_cast(op)) { - Type dtype = converter->convertType(pow.getType()) - .cast() + Type dtype = cast(converter->convertType(pow.getType())) .getElementType(); if (!isa(dtype)) { pow.emitError("unimplemented: non-floating point dtype"); @@ -944,8 +937,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto imag = dyn_cast(op)) { - Type dtype = converter->convertType(imag.getType()) - .cast() + Type dtype = cast(converter->convertType(imag.getType())) .getElementType(); if (!isa(dtype)) { imag.emitError("unimplemented: non-floating point dtype"); @@ -956,8 +948,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto real = dyn_cast(op)) { - Type dtype = converter->convertType(real.getType()) - .cast() + Type dtype = cast(converter->convertType(real.getType())) .getElementType(); if (!isa(dtype)) { real.emitError("unimplemented: non-floating point dtype"); @@ -968,7 +959,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto gtScalar = dyn_cast(op)) { - Type dtype = gtScalar.getSelf().getType().cast().getDtype(); + Type dtype = cast(gtScalar.getSelf().getType()).getDtype(); // TODO: `gtTensor` and `gtScalar` share similar code and can be called from // one static function. @@ -998,7 +989,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto geScalar = dyn_cast(op)) { - Type dtype = geScalar.getSelf().getType().cast().getDtype(); + Type dtype = cast(geScalar.getSelf().getType()).getDtype(); // TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that // can be refactored. @@ -1028,7 +1019,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto eqScalar = dyn_cast(op)) { - Type dtype = eqScalar.getSelf().getType().cast().getDtype(); + Type dtype = cast(eqScalar.getSelf().getType()).getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); @@ -1044,7 +1035,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto neScalar = dyn_cast(op)) { - Type dtype = neScalar.getSelf().getType().cast().getDtype(); + Type dtype = cast(neScalar.getSelf().getType()).getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); @@ -1060,7 +1051,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto ltScalar = dyn_cast(op)) { - Type dtype = ltScalar.getSelf().getType().cast().getDtype(); + Type dtype = cast(ltScalar.getSelf().getType()).getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); @@ -1088,7 +1079,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto leScalar = dyn_cast(op)) { - Type dtype = leScalar.getSelf().getType().cast().getDtype(); + Type dtype = cast(leScalar.getSelf().getType()).getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); @@ -1116,9 +1107,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto whereSelf = dyn_cast(op)) { - Type dtype = converter->convertType(whereSelf.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(whereSelf.getType())) + .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype); return b.create(loc, payloadArgs[0], lhs, rhs); @@ -1141,7 +1132,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, start, weightedDelta); } if (auto minimum = dyn_cast(op)) { - Type dtype = minimum.getType().cast().getDtype(); + Type dtype = cast(minimum.getType()).getDtype(); Type elemTy = converter->convertType(minimum.getType()) .cast() .getElementType(); @@ -1151,7 +1142,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, pred, lhs, rhs); } if (auto maximum = dyn_cast(op)) { - Type dtype = maximum.getType().cast().getDtype(); + Type dtype = cast(maximum.getType()).getDtype(); Type elemTy = converter->convertType(maximum.getType()) .cast() .getElementType(); @@ -1170,15 +1161,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } - Type dtype = converter->convertType(clamp.getType()) - .cast() + Type dtype = cast(converter->convertType(clamp.getType())) .getElementType(); if (!isa(dtype)) { clamp.emitError("unimplement type for clamp"); return nullptr; } - Type dstOriginalDtype = clamp.getType().cast().getDtype(); + Type dstOriginalDtype = cast(clamp.getType()).getDtype(); bool isUnsigned = isa(dstOriginalDtype); if (auto intTy = dyn_cast(dstOriginalDtype)) { isUnsigned = intTy.isUnsigned(); @@ -1219,9 +1209,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( clampTensor.emitError("unimplemented: runtime optional type"); return nullptr; } - Type dtype = converter->convertType(clampTensor.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(clampTensor.getType())) + .getElementType(); bool isMinNone = true; auto result = payloadArgs[0]; if (!min.getType().isa()) { @@ -1263,8 +1253,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } if (auto rsub = dyn_cast(op)) { - Type dtype = converter->convertType(rsub.getType()) - .cast() + Type dtype = cast(converter->convertType(rsub.getType())) .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); @@ -1283,9 +1272,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } if (auto mulScalar = dyn_cast(op)) { - Type dtype = converter->convertType(mulScalar.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(mulScalar.getType())) + .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); if (isa(dtype)) @@ -1297,9 +1286,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto atenToDtype = dyn_cast(op)) { Value input = payloadArgs[0]; - Type dtype = converter->convertType(atenToDtype.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(atenToDtype.getType())) + .getElementType(); Type resultElementType; int64_t dtypeInt; if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) { @@ -1320,9 +1309,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } if (auto divScalar = dyn_cast(op)) { - Type dtype = converter->convertType(divScalar.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(divScalar.getType())) + .getElementType(); if (!isa(dtype)) { divScalar.emitError("unimplemented: non-floating point dtype"); return nullptr; @@ -1395,9 +1384,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } if (auto reciprocal = dyn_cast(op)) { - Type dtype = converter->convertType(reciprocal.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(reciprocal.getType())) + .getElementType(); Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Type elementType = arg.getType(); // assert(element != 0) @@ -1416,9 +1405,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( // The approach used here is as follows: // result = self <= threshold ? value : self AtenThresholdOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(thresholdOp.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(thresholdOp.getType())) + .getElementType(); Value self = payloadArgs[0]; Value threshold = @@ -1438,8 +1427,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( // The approach used here is as follows: // result = self <= threshold ? 0 : grad AtenThresholdBackwardOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(thresholdBackward.getType()) - .cast() + Type dtype = cast( + converter->convertType(thresholdBackward.getType())) .getElementType(); Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); @@ -1459,15 +1448,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto fillScalar = dyn_cast(op)) { AtenFillScalarOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(fillScalar.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(fillScalar.getType())) + .getElementType(); return convertScalarToDtype(b, loc, adaptor.getValue(), dtype); } if (auto maskedFillTensor = dyn_cast(op)) { AtenMaskedFillScalarOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(maskedFillTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(maskedFillTensor.getType())) .getElementType(); Value input = payloadArgs[0]; @@ -1477,9 +1466,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto fillTensor = dyn_cast(op)) { AtenFillTensorOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(fillTensor.getType()) - .cast() - .getElementType(); + Type dtype = + cast(converter->convertType(fillTensor.getType())) + .getElementType(); return convertScalarToDtype(b, loc, payloadArgs[1], dtype); } @@ -1519,7 +1508,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto value = payloadArgs[0]; auto valueTy = value.getType(); auto qtensor = op->getOperand(0); - auto qtensorTy = qtensor.getType().cast().getDtype(); + auto qtensorTy = cast(qtensor.getType()).getDtype(); Value zp, scale; if (auto makeQTensor = @@ -1744,8 +1733,8 @@ class ConvertAtenNllLossForwardOp Value ignoreIndex = adaptor.getIgnoreIndex(); Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex); - unsigned inputRank = input.getType().cast().getRank(); - unsigned targetRank = target.getType().cast().getRank(); + unsigned inputRank = cast(input.getType()).getRank(); + unsigned targetRank = cast(target.getType()).getRank(); // TODO: Add support for k-dim loss. if (inputRank > 2) { @@ -1931,11 +1920,11 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { failed(checkNotNone(rewriter, op, runningVar))) return failure(); - auto inputType = input.getType().cast(); - auto weightType = weight.getType().cast(); - auto biasType = bias.getType().cast(); - auto runningMeanType = runningMean.getType().cast(); - auto runningVarType = runningVar.getType().cast(); + auto inputType = cast(input.getType()); + auto weightType = cast(weight.getType()); + auto biasType = cast(bias.getType()); + auto runningMeanType = cast(runningMean.getType()); + auto runningVarType = cast(runningVar.getType()); auto inputRank = inputType.getRank(); if (inputRank < 2) @@ -2032,9 +2021,9 @@ class ConvertAtenNllLossBackwardOp Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex()); Value totalWeight = adaptor.getTotalWeight(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); int inputRank = inputType.getRank(); - auto gradOutputType = gradOutput.getType().cast(); + auto gradOutputType = cast(gradOutput.getType()); Type resultElementType = gradOutputType.getElementType(); int64_t reduction; @@ -2059,7 +2048,7 @@ class ConvertAtenNllLossBackwardOp createZeroInitTensor(rewriter, loc, outputSize, resultElementType); auto getAffineMapForSingleElementTensor = [&](Value tensor) { - auto tensorType = tensor.getType().cast(); + auto tensorType = cast(tensor.getType()); SmallVector affineExprs(tensorType.getRank(), rewriter.getAffineConstantExpr(0)); return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs, @@ -2188,12 +2177,12 @@ class ConvertPrimsSplitDimOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - auto aRankedTensorType = adaptor.getA().getType().cast(); + auto aRankedTensorType = cast(adaptor.getA().getType()); const TypeConverter *typeConverter = getTypeConverter(); auto resultRankedTensorType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); // The dimension being split must be statically known. @@ -2233,11 +2222,11 @@ class ConvertPrimsCollapseOp : public OpConversionPattern { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - auto aRankedTensorType = adaptor.getA().getType().cast(); + auto aRankedTensorType = cast(adaptor.getA().getType()); const TypeConverter *typeConverter = getTypeConverter(); auto resultRankedTensorType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); // Collapse range must be statically known. int64_t startInt; @@ -2328,7 +2317,7 @@ class ConvertLogitOp : public OpConversionPattern { return failure(); } - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputElementType = inputType.getElementType(); if (!isa(inputElementType)) { @@ -2433,8 +2422,8 @@ class ConvertDequantizePerChannel return failure(); } - auto operandDTy = operand.getType().cast().getDtype(); - auto zeropointDTy = zeropoint.getType().cast().getDtype(); + auto operandDTy = cast(operand.getType()).getDtype(); + auto zeropointDTy = cast(zeropoint.getType()).getDtype(); operand = converter->materializeTargetConversion( rewriter, loc, converter->convertType(operand.getType()), operand); scale = converter->materializeTargetConversion( @@ -2537,7 +2526,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value twoFloat = rewriter.create( loc, rewriter.getFloatAttr(floatType, 2.0)); Value input = adaptor.getInput(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputShape = inputType.getShape(); Value innerDim0a = rewriter.create(loc, input, 2); Value innerDim1a = rewriter.create(loc, input, 3); @@ -2558,7 +2547,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value innerDim1e = rewriter.create(loc, innerDim1d, twoFloat); Value grid = adaptor.getGrid(); - auto gridType = grid.getType().cast(); + auto gridType = cast(grid.getType()); auto gridShape = gridType.getShape(); auto gridRank = gridType.getRank(); SmallVector extractGridOffsets0(gridRank, zeroIndex); diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index f1749e026ead..c015ce563dd6 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -37,9 +37,8 @@ Value torch_to_linalg::getPaddedTensor( SmallVectorImpl &lowPaddingInts, SmallVectorImpl &highPaddingInts, Value pad) { Location loc = op->getLoc(); - Type rankedTensorType = - tensor::PadOp::inferResultType(input.getType().cast(), - lowPaddingInts, highPaddingInts); + Type rankedTensorType = tensor::PadOp::inferResultType( + cast(input.getType()), lowPaddingInts, highPaddingInts); SmallVector lowPaddings = getIndexIntsAsOpFoldResult(b, lowPaddingInts); SmallVector highPaddings = @@ -61,7 +60,7 @@ Value torch_to_linalg::getZeroPaddedTensor( Location loc = op->getLoc(); Value c0 = b.create( loc, - b.getZeroAttr(input.getType().cast().getElementType())); + b.getZeroAttr(cast(input.getType()).getElementType())); return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0); } @@ -73,7 +72,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( int unpaddedDims, Value pad) { assert(input.getType().isa() && "input must be RankedTensorType"); - unsigned int inRank = input.getType().cast().getRank(); + unsigned int inRank = cast(input.getType()).getRank(); Location loc = op->getLoc(); SmallVector inputDims = getTensorSizes(b, loc, input); @@ -86,7 +85,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( pad < paddingIncludingUnchanged.end(); pad++) *pad = castIntToIndex(b, loc, *pad); - Type elementType = input.getType().cast().getElementType(); + Type elementType = cast(input.getType()).getElementType(); // TODO: audit possibility of sparsity on this tensor Type inputType = RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef( @@ -158,7 +157,7 @@ Value torch_to_linalg::getOutputDimForConvTransposeOps( Value torch_to_linalg::createReductionLinalgGeneric( OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem, function_ref bodyBuild) { - auto inputType = opInfo.tensorOperand.getType().cast(); + auto inputType = cast(opInfo.tensorOperand.getType()); // Get the result shape by obtaining the size of each // dimension in the input tensor that is not getting reduced. @@ -237,7 +236,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( SmallVector operandRanks; operandRanks.resize(tensorOperands.size()); llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) { - return tensor.getType().dyn_cast().getRank(); + return dyn_cast(tensor.getType()).getRank(); }); auto resultRankIt = @@ -253,7 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b); for (Value tensorOperand : tensorOperands) { SmallVector exprs; - auto type = tensorOperand.getType().cast(); + auto type = cast(tensorOperand.getType()); for (auto size : llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) { // If the size is statically known to be 1, we don't want any @@ -327,7 +326,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Operation *op, PatternRewriter &rewriter, Value input, SmallVector broadcastToShape, RankedTensorType broadcastType, Value &result, SmallVector useBroadcastToShape) { - RankedTensorType inputType = input.getType().cast(); + RankedTensorType inputType = cast(input.getType()); int64_t inputRank = inputType.getRank(); int64_t outputRank = broadcastToShape.size(); ArrayRef outputShape = broadcastType.getShape(); @@ -525,7 +524,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, Value tensor) { - auto tensorType = tensor.getType().cast(); + auto tensorType = cast(tensor.getType()); auto rank = tensorType.getRank(); SmallVector unknownSizes(rank, kUnknownSize); return b.create( diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 5cc6b0928898..d01f0dafa758 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -66,8 +66,8 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, mlir::Value &self, mlir::Value &other, size_t dimSizeIndexBits) { - auto selfTy = self.getType().template dyn_cast(); - auto otherTy = other.getType().template dyn_cast(); + auto selfTy = dyn_cast(self.getType()); + auto otherTy = dyn_cast(other.getType()); auto selfRank = selfTy.getRank(); auto otherRank = otherTy.getRank(); if (selfRank == 0 || otherRank == 0) @@ -171,7 +171,7 @@ class ConvertAtenUnaryOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfType = self.getType().cast(); + auto selfType = cast(self.getType()); if (!selfType) { return op.emitError("only Tensor types supported in StableHLO"); } @@ -197,12 +197,12 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only Tensor types supported in StableHLO"); - if (selfTy.getElementType().isa()) { + if (isa(selfTy.getElementType())) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( @@ -229,14 +229,14 @@ class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only Tensor types supported in StableHLO"); auto resultTy = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - if (resultTy.getElementType().template isa()) { + if (isa(resultTy.getElementType())) { Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); rewriter.replaceOpWithNewOp(op, resultTy, src); return success(); @@ -304,8 +304,7 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputType = - adaptor.getA().getType().template dyn_cast(); + auto inputType = dyn_cast(adaptor.getA().getType()); if (!inputType) op.emitError("only Tensor types supported in StableHLO"); @@ -313,8 +312,7 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { Value input = adaptor.getA(); SmallVector inputSizes = getTensorSizes(rewriter, loc, input); int64_t inputRank = inputSizes.size(); - Type inputDtype = - op.getA().getType().template cast().getDtype(); + Type inputDtype = cast(op.getA().getType()).getDtype(); Value constantOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); @@ -345,9 +343,9 @@ class ConvertAtenBinaryBroadcastOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); Value rhs = adaptor.getOther(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return op.emitError("only Tensor types supported"); @@ -378,9 +376,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - RankedTensorType lhsType = lhs.getType().dyn_cast(); + RankedTensorType lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); - RankedTensorType rhsType = rhs.getType().dyn_cast(); + RankedTensorType rhsType = dyn_cast(rhs.getType()); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); @@ -433,9 +431,9 @@ class ConvertAtenMulDivOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - auto lhsType = lhs.getType().dyn_cast(); + auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); - TensorType rhsType = rhs.getType().dyn_cast(); + TensorType rhsType = dyn_cast(rhs.getType()); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); @@ -527,8 +525,8 @@ class ConvertAtenCompareOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); - RankedTensorType lhsTy = lhs.getType().dyn_cast(); - RankedTensorType rhsTy = rhs.getType().dyn_cast(); + RankedTensorType lhsTy = dyn_cast(lhs.getType()); + RankedTensorType rhsTy = dyn_cast(rhs.getType()); if (!lhsTy) return op.emitError("only Tensor types supported in StableHLO"); @@ -616,8 +614,8 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); - RankedTensorType lhsTy = lhs.getType().dyn_cast(); - RankedTensorType rhsTy = rhs.getType().dyn_cast(); + RankedTensorType lhsTy = dyn_cast(lhs.getType()); + RankedTensorType rhsTy = dyn_cast(rhs.getType()); if (!lhsTy) return op.emitError("lhs must be a ranked tensor type"); @@ -659,11 +657,10 @@ class ConvertAtenTransposeIntOp return rewriter.notifyMatchFailure(op, "dim1 must be constant"); } - auto inType = self.getType().cast(); + auto inType = cast(self.getType()); auto inputRank = inType.getRank(); - auto outType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto outType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); dim0 = toPositiveDim(dim0, inputRank); if (!isValidDim(dim0, inputRank)) { @@ -691,7 +688,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto outType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); rewriter.replaceOpWithNewOp(op, outType, self); return success(); } @@ -701,7 +698,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSizeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return op.emitError("only tensor types are currently supported"); @@ -739,7 +736,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value other = adaptor.getOther(); auto outType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); // promote self and other types self = hlo::promoteType(rewriter, op.getLoc(), self, outType); other = hlo::promoteType(rewriter, op.getLoc(), other, outType); @@ -764,10 +761,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); - auto outType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto selfTy = cast(self.getType()); + auto outType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); if (options.enableStaticShape && selfTy.hasStaticShape()) { Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); @@ -831,10 +827,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); // Not a ranked tensor type - auto inType = self.getType().dyn_cast(); - auto outType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto inType = dyn_cast(self.getType()); + auto outType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); if (!inType) return op.emitError("only ranked tensor types with static shapes are " "currently supported"); @@ -861,15 +856,14 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. // TODO: what about unsigned integer? - if (auto elements = op.getValueAttr().dyn_cast()) { + if (auto elements = dyn_cast(op.getValueAttr())) { Type builtinTensorElemTy = resultType.getElementType(); unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); @@ -892,9 +886,8 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenTensorIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type outElementType = resultType.getElementType(); Value innerValue = adaptor.getT(); Value stablehloTensor = @@ -910,10 +903,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); - if (!inputTy.getElementType().isa()) { + cast(getTypeConverter()->convertType(op.getType())); + if (!isa(inputTy.getElementType())) { return op.emitError("only floating-point datatype legalization supported " "for AtenReciprocalOp"); } @@ -929,9 +922,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); - auto lhsType = lhs.getType().dyn_cast(); + auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getExponent(); - TensorType rhsType = rhs.getType().dyn_cast(); + TensorType rhsType = dyn_cast(rhs.getType()); if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); @@ -1002,9 +995,8 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimNumToTensorScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - RankedTensorType outputType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType outputType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); auto outputElemType = outputType.getElementType(); Value stablehloTensor = hlo::scalarToStablehloTensor( rewriter, op, adaptor.getA(), outputElemType); @@ -1018,8 +1010,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenScalarImplicitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - Type inputDtype = - op.getA().getType().template cast().getDtype(); + Type inputDtype = cast(op.getA().getType()).getDtype(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); auto result = rewriter.create(loc, adaptor.getA()); @@ -1037,7 +1028,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return op.emitError("only tensor types are currently supported"); @@ -1055,7 +1046,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); auto lhsElemTy = lhsTy.getElementType(); if (!isa(lhsElemTy)) { @@ -1080,7 +1071,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - auto inputTy = input.getType().template dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } @@ -1103,11 +1094,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenLog2Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().template dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } - auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + auto outTy = cast(getTypeConverter()->convertType(op.getType())); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input); @@ -1124,12 +1115,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenLog10Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().template dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return op.emitError("only ranked tensor type is supported."); } - auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + auto outTy = cast(getTypeConverter()->convertType(op.getType())); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input); @@ -1146,8 +1137,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenErfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputType = input.getType().cast(); - if (!inputType.getElementType().isa()) { + auto inputType = cast(input.getType()); + if (!isa(inputType.getElementType())) { return rewriter.notifyMatchFailure(op, "only float tensor is supported"); } rewriter.replaceOpWithNewOp( @@ -1161,7 +1152,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getInput(); - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); Value runningMean = adaptor.getRunningMean(); @@ -1174,10 +1165,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // all of NC, NCL, NCHW, NCDHW's feature index is 1. int64_t feature_index = 1; - if (!inputTy.getElementType().template isa()) { + if (!isa(inputTy.getElementType())) { return op.emitError("only input tensor of float type is supported"); } - auto inputElemTy = inputTy.getElementType().cast(); + auto inputElemTy = cast(inputTy.getElementType()); Value channelDim = rewriter.create(op->getLoc(), input, feature_index); @@ -1220,20 +1211,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( inputTy.getElementType())); } - auto weightTy = weight.getType().cast(); - auto biasTy = bias.getType().cast(); - auto runningMeanTy = runningMean.getType().cast(); - auto runningVarTy = runningVar.getType().cast(); + auto weightTy = cast(weight.getType()); + auto biasTy = cast(bias.getType()); + auto runningMeanTy = cast(runningMean.getType()); + auto runningVarTy = cast(runningVar.getType()); if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } - if (!weightTy.getElementType().template isa() || - !biasTy.getElementType().template isa() || - !runningMeanTy.getElementType().template isa() || - !runningVarTy.getElementType().template isa()) { + if (!isa(weightTy.getElementType()) || + !isa(biasTy.getElementType()) || + !isa(runningMeanTy.getElementType()) || + !isa(runningVarTy.getElementType())) { return op.emitError("only float weight/bias/runningMean/runningVar tensor " "of float type is supported"); } @@ -1261,8 +1252,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // supported mixed types, like input type is fp16 and weight type is fp32. if (inputTy.getElementType() != weightTy.getElementType()) { RankedTensorType convertedType = inputTy; - if (weightTy.getElementType().cast().getWidth() > - inputTy.getElementType().cast().getWidth()) { + if (cast(weightTy.getElementType()).getWidth() > + cast(inputTy.getElementType()).getWidth()) { convertedType = RankedTensorType::get(inputTy.getShape(), weightTy.getElementType()); } @@ -1302,8 +1293,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // supported mixed types, like input type is fp16 and weight type is fp32. if (inputTy.getElementType() != weightTy.getElementType()) { RankedTensorType convertedType = inputTy; - if (weightTy.getElementType().cast().getWidth() > - inputTy.getElementType().cast().getWidth()) { + if (cast(weightTy.getElementType()).getWidth() > + cast(inputTy.getElementType()).getWidth()) { convertedType = RankedTensorType::get(inputTy.getShape(), weightTy.getElementType()); } @@ -1340,7 +1331,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenNativeLayerNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getInput(); - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); Value weight = adaptor.getWeight(); @@ -1365,12 +1356,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( failed(checkNotNone(rewriter, op, bias))) { return op->emitError("none weight or bias is unsupported"); } - auto weightTy = weight.getType().cast(); - auto biasTy = bias.getType().cast(); + auto weightTy = cast(weight.getType()); + auto biasTy = cast(bias.getType()); - if (!inputTy.getElementType().isa() || - !biasTy.getElementType().isa() || - !weightTy.getElementType().isa()) { + if (!isa(inputTy.getElementType()) || + !isa(biasTy.getElementType()) || + !isa(weightTy.getElementType())) { return op->emitError("currently only float data type are supported"); } int64_t normalizedShapeRank = normalizedShape.size(); @@ -1423,7 +1414,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector oneConstVec( numFeatureDimSize, APFloat( - inputTy.getElementType().cast().getFloatSemantics(), + cast(inputTy.getElementType()).getFloatSemantics(), 1)); auto oneOrZeroConstType = RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); @@ -1443,9 +1434,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Reshape back auto outputTy = - getTypeConverter()->convertType(op.getType(0)).cast(); + cast(getTypeConverter()->convertType(op.getType(0))); auto outputMeanOrVarTy = - getTypeConverter()->convertType(op.getType(1)).cast(); + cast(getTypeConverter()->convertType(op.getType(1))); auto output = rewriter.create( op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), @@ -1482,7 +1473,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenCatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto outType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure(op, @@ -1516,7 +1507,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenNumelOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); - auto selfTy = self.getType().dyn_cast(); + auto selfTy = dyn_cast(self.getType()); size_t rank = selfTy.getRank(); Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); @@ -1544,7 +1535,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenClampOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputElemType = inputType.getElementType(); Value minValue = adaptor.getMin(); Value maxValue = adaptor.getMax(); @@ -1716,7 +1707,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto outType = - this->getTypeConverter()->convertType(op.getType()).cast(); + cast(this->getTypeConverter()->convertType(op.getType())); if (!outType) { return op.emitError("only tensor type is supported"); } @@ -1764,15 +1755,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenPowTensorTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); Value rhs = adaptor.getExponent(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return op.emitError("only Tensor types supported"); auto outTy = - this->getTypeConverter()->convertType(op.getType()).cast(); + cast(this->getTypeConverter()->convertType(op.getType())); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); @@ -1790,12 +1781,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value generator = adaptor.getGenerator(); Location loc = op.getLoc(); - if (!generator.getType().isa()) + if (!isa(generator.getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); - auto elements = self.getType().cast().getShape(); + auto elements = cast(self.getType()).getShape(); if (llvm::any_of(elements, [](int64_t dim) { return dim == ShapedType::kDynamic; })) return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD"); @@ -1824,14 +1815,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // The pin_memory should be either `False` or `none`. bool pinMemory; - if (!op.getPinMemory().getType().template isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) return rewriter.notifyMatchFailure( op, "unimplemented: pin_memory must be either None or false"); // Only `none`, `contiguous` and `preserve` memory_format is supported. - if (!op.getMemoryFormat().getType().isa()) { + if (!isa(op.getMemoryFormat().getType())) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( @@ -1844,7 +1835,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "memory_format is supported"); } - if (!op.getDevice().getType().isa()) { + if (!isa(op.getDevice().getType())) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( @@ -1853,7 +1844,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. - if (!op.getLayout().getType().isa()) { + if (!isa(op.getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( @@ -1876,9 +1867,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); auto resultType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); Type resultElementType; - if (op.getDtype().getType().isa()) { + if (isa(op.getDtype().getType())) { resultElementType = resultType.getElementType(); } else { int64_t dtypeInt; @@ -1931,7 +1922,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenFillScalarOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto outType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); auto dtype = outType.getElementType(); Value scalarTensor = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype); @@ -1951,7 +1942,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); auto outType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); SmallVector dims; if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) { diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index ac1c8bacf9a8..00c022cc1067 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -64,7 +64,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, loc, rewriter.getIntegerAttr(intType, 1)); // sliceSizes - auto inputRankTy = input.getType().dyn_cast(); + auto inputRankTy = dyn_cast(input.getType()); auto inputRank = inputRankTy.getRank(); SmallVector sliceSizes; sliceSizes.reserve(inputRank); @@ -85,7 +85,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, for (int64_t r = 0; r < axis; ++r) { offsetDims.push_back(r); } - auto indicesRankTy = indices.getType().dyn_cast(); + auto indicesRankTy = dyn_cast(indices.getType()); auto indicesRank = indicesRankTy.getRank(); for (int64_t r = axis + 1; r < inputRank; ++r) { offsetDims.push_back(r + indicesRank - 1); @@ -132,8 +132,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, SmallVector &strides) { Location loc = op.getLoc(); auto input = adaptor.getSelf(); - RankedTensorType inputType = - input.getType().template cast(); + RankedTensorType inputType = cast(input.getType()); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); @@ -161,7 +160,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, int64_t step; if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { - if (!op.getStep().getType().template isa()) + if (!isa(op.getStep().getType())) return op->emitError("unimplemented: step is not constant"); step = 1; } @@ -225,7 +224,7 @@ FailureOr broadcastAndConcatIndices(Operation *op, // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { auto indexTensor = indexTensors[i]; - auto indexTensorType = indexTensor.getType().cast(); + auto indexTensorType = cast(indexTensor.getType()); for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) { if (size == kUnknownSize) return failure(); @@ -249,7 +248,7 @@ FailureOr broadcastAndConcatIndices(Operation *op, SmallVector broadcastedIndices; Type indexElemTy = - indexTensors[0].getType().cast().getElementType(); + cast(indexTensors[0].getType()).getElementType(); RankedTensorType bcastIndexType = RankedTensorType::get(indicesShape, indexElemTy); for (auto indexTensor : indexTensors) { @@ -290,7 +289,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenEmbeddingOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto weight = adaptor.getWeight(); - auto weightTy = weight.getType().cast(); + auto weightTy = cast(weight.getType()); if (!weightTy) return op.emitError("only ranked tensor types are supported"); @@ -332,17 +331,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value indices = adaptor.getIndices(); Value offsets = adaptor.getOffsets(); - auto weightTy = weight.getType().cast(); + auto weightTy = cast(weight.getType()); if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2) return rewriter.notifyMatchFailure( op, "weight must be rank 2 tensor with static shapes"); - auto indicesTy = indices.getType().cast(); + auto indicesTy = cast(indices.getType()); if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1) return rewriter.notifyMatchFailure( op, "indices must be a vector with static shapes"); - auto offsetsTy = offsets.getType().cast(); + auto offsetsTy = cast(offsets.getType()); if (offsetsTy && offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() && offsetsTy.getShape()[0] == 1) return rewriter.notifyMatchFailure( @@ -485,7 +484,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexSelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only ranked tensor types are supported"); int64_t dim; @@ -514,8 +513,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Location loc = op->getLoc(); Value input = adaptor.getSelf(); Value index = adaptor.getIndex(); - auto inputType = input.getType().cast(); - auto indexType = index.getType().cast(); + auto inputType = cast(input.getType()); + auto indexType = cast(index.getType()); auto indexElemType = indexType.getElementType(); if (indexType.getRank() != inputType.getRank()) { @@ -623,7 +622,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } Value src = adaptor.getSrc(); - auto srcType = src.getType().cast(); + auto srcType = cast(src.getType()); int64_t srcRank = srcType.getRank(); SmallVector srcAbstractSizes(srcRank, kUnknownSize); auto abstractSrcType = RankedTensorType::get( @@ -651,9 +650,9 @@ class ConvertAtenScatterOp : public ConvertAtenOp { Value input = adaptor.getSelf(); Value index = adaptor.getIndex(); Value src = adaptor.getSrc(); - auto inputType = input.getType().cast(); - auto indexType = index.getType().cast(); - auto srcType = src.getType().cast(); + auto inputType = cast(input.getType()); + auto indexType = cast(index.getType()); + auto srcType = cast(src.getType()); auto indexElemType = indexType.getElementType(); if (indexType.getRank() != inputType.getRank() || @@ -789,9 +788,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); Value input = adaptor.getSelf(); - auto inputTensorType = input.getType().cast(); + auto inputTensorType = cast(input.getType()); auto outType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); auto outShape = outType.getShape(); Value indexList = op.getIndices(); SmallVector indicesTorchType; @@ -857,10 +856,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value input = adaptor.getSelf(); Value values = adaptor.getValues(); auto outType = - getTypeConverter()->convertType(op.getType()).cast(); - auto inputType = input.getType().cast(); + cast(getTypeConverter()->convertType(op.getType())); + auto inputType = cast(input.getType()); int64_t inputRank = inputType.getRank(); - auto valuesType = values.getType().cast(); + auto valuesType = cast(values.getType()); auto valuesShape = valuesType.getShape(); bool accumulate; if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) { diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index f95184833841..70028cd2df49 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -32,7 +32,7 @@ namespace { Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, ArrayRef shape, ArrayRef dimSizes, ArrayRef broadcastDims) { - auto tensorTy = tensor.getType().dyn_cast(); + auto tensorTy = dyn_cast(tensor.getType()); auto loc = op->getLoc(); Value stablehloShape = rewriter.create(loc, dimSizes); @@ -48,7 +48,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, ArrayRef inpTransDims) { - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); auto rank = inputTy.getRank(); auto transDims = hlo::toPositiveDims(inpTransDims, rank); auto inpShape = inputTy.getShape(); @@ -70,8 +70,8 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, int64_t lhsResultDim, int64_t rhsResultDim, int64_t lhsContractingDim, int64_t rhsContractingDim) { - auto lhsTy = lhs.getType().dyn_cast(); - auto rhsTy = rhs.getType().dyn_cast(); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); auto oldLhsShape = lhsTy.getShape(); auto oldRhsShape = rhsTy.getShape(); @@ -129,8 +129,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, size_t dimSizeIndexBits) { Value lhs = inpLhs; Value rhs = inpRhs; - auto lhsRankTy = inpLhs.getType().dyn_cast(); - auto rhsRankTy = inpRhs.getType().dyn_cast(); + auto lhsRankTy = dyn_cast(inpLhs.getType()); + auto rhsRankTy = dyn_cast(inpRhs.getType()); auto lhsRank = lhsRankTy.getRank(); auto rhsRank = rhsRankTy.getRank(); @@ -177,8 +177,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, return; } - lhsShape = lhs.getType().cast().getShape(); - rhsShape = rhs.getType().cast().getShape(); + lhsShape = cast(lhs.getType()).getShape(); + rhsShape = cast(rhs.getType()).getShape(); // check shape compatibility, check if we should broadcast // first, we should got a new batch shape. Check from (0, nBatchDims) @@ -266,8 +266,8 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Value &lhs, Value &rhs, Value &output) const { - auto lhsTy = lhs.getType().cast(); - auto rhsTy = rhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -370,10 +370,10 @@ class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { ConversionPatternRewriter &rewriter, Value &lhs, Value &rhs) const override { lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); rhs = adaptor.getOther(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return op.emitError( @@ -393,10 +393,10 @@ class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { ConversionPatternRewriter &rewriter, Value &lhs, Value &rhs) const override { lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); rhs = adaptor.getMat2(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return op.emitError( @@ -429,10 +429,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { ConversionPatternRewriter &rewriter, Value &lhs, Value &rhs) const override { lhs = adaptor.getInput(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); rhs = adaptor.getWeight(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return op.emitError( @@ -464,16 +464,15 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto biasTy = bias.getType(); // StableHLO does not mandate that elementwise op tensors need to be ranked. - if (!biasTy.template isa() && - !biasTy.template isa()) + if (!isa(biasTy) && !isa(biasTy)) return op.emitError("only ranked tensor types are supported in StableHLO " "matmul for bias tensor"); // weight.T rhs = getPermutedTensor(rewriter, op, rhs, {1, 0}); - auto lhsTy = lhs.getType().cast(); - auto rhsTy = rhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(), rhsTy.getRank() - lhsTy.getRank()); @@ -503,7 +502,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); Value matmulPlusBias = matmulOutput; - if (!biasTy.template isa()) { + if (!isa(biasTy)) { // Bias addition broadcasts to the matmul output shape. matmulPlusBias = rewriter .create( @@ -525,7 +524,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op, Value weight, int64_t groups) const { - auto weightTy = weight.getType().cast(); + auto weightTy = cast(weight.getType()); auto weightElemTy = weightTy.getElementType(); auto rank = weightTy.getRank(); const auto &options = getOptions(); @@ -588,8 +587,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { ArrayRef dilation, ArrayRef outputPadding, int64_t groups) const { - auto inputTy = input.getType().cast(); - auto weightTy = weight.getType().cast(); + auto inputTy = cast(input.getType()); + auto weightTy = cast(weight.getType()); auto weightShape = weightTy.getShape(); auto nDims = inputTy.getRank(); @@ -727,11 +726,11 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { Value weight = adaptor.getWeight(); // The input shape is [N, C, H, W] - auto inputTy = input.getType().template cast(); + auto inputTy = cast(input.getType()); // The weight shape is [OC, (IC//G), KH, KW] // If transposed is set to true, // the weight shape changes to [IC, (OC//G), KH, KW] - auto weightTy = weight.getType().template cast(); + auto weightTy = cast(weight.getType()); auto outTy = getTypeConverter() ->convertType(op.getType()) .template cast(); @@ -819,11 +818,11 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { } // Handle bias - if (!bias.getType().cast()) { + if (!cast(bias.getType())) { return op.emitError("bias provided but not a ranked tensor"); } - auto biasTy = bias.getType().cast(); + auto biasTy = cast(bias.getType()); if (!biasTy.getElementType().isIntOrFloat()) { return op.emitError("only floating-point or integer datatype " "legalization for bias supported"); diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index b8a5321306bb..132410a2a358 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -81,12 +81,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenMaxPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); if (inputRank <= 2) { return op.emitError( @@ -176,14 +176,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); auto inputElemTy = inputTy.getElementType(); auto inputShape = inputTy.getShape(); auto inputRank = inputTy.getRank(); auto outValTy = - getTypeConverter()->convertType(op.getType(0)).cast(); + cast(getTypeConverter()->convertType(op.getType(0))); auto outIdxTy = - getTypeConverter()->convertType(op.getType(1)).cast(); + cast(getTypeConverter()->convertType(op.getType(1))); if (inputRank <= 2) { return op.emitError( @@ -366,7 +366,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value input = adaptor.getSelf(); - RankedTensorType inputTy = input.getType().cast(); + RankedTensorType inputTy = cast(input.getType()); Type inputElemTy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); RankedTensorType outTy = ConvertAtenOp::getTypeConverter() @@ -539,11 +539,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenCumsumOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); - inputTy = input.getType().cast(); + inputTy = cast(input.getType()); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto inputShape = inputTy.getShape(); diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index fee5cc01e4ae..81a1a1f564d1 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -126,7 +126,7 @@ static std::optional getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, ArrayRef inputShapeVec, int64_t dim, size_t dimSizeIndexBits) { - auto inputTy = input.getType().template cast(); + auto inputTy = cast(input.getType()); if (!inputTy) { return std::nullopt; } @@ -249,7 +249,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenArgmaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().template cast(); + auto inputTy = cast(input.getType()); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); @@ -321,7 +321,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenMaxDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().template dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); @@ -410,7 +410,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenSumOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); auto outTy = getTypeConverter() ->convertType(op.getType()) .template dyn_cast(); @@ -423,7 +423,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto dstElemTy = outTy.getElementType(); input = rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = input.getType().dyn_cast(); + inputTy = dyn_cast(input.getType()); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { @@ -626,7 +626,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenProdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); auto outTy = getTypeConverter() ->convertType(op.getType()) .template dyn_cast(); @@ -639,7 +639,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto dstElemTy = outTy.getElementType(); input = rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = input.getType().dyn_cast(); + inputTy = dyn_cast(input.getType()); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { @@ -699,7 +699,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); @@ -762,7 +762,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenMinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); @@ -825,7 +825,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenSumDimIntListOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); auto outTy = getTypeConverter() ->convertType(op.getType()) .template dyn_cast(); @@ -838,7 +838,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto dstElemTy = outTy.getElementType(); input = rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = input.getType().dyn_cast(); + inputTy = dyn_cast(input.getType()); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { @@ -958,7 +958,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); - auto inputType = input.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); if (!inputType) { return op.emitError( "only ranked tensor input supported in AtenFrobeniusNormDimOp"); @@ -1070,7 +1070,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); - auto inputType = input.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); if (!inputType) { return op.emitError( "only ranked tensor input supported in AtenLinalgVectorNormOp"); @@ -1078,7 +1078,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( int64_t inputRank = inputType.getRank(); auto outType = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); auto outElemType = outType.getElementType(); if (!isa(outElemType)) { return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp"); diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 551f79c47288..5db6ee339b09 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -144,7 +144,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Value promoteType(PatternRewriter &rewriter, Location loc, Value input, TensorType outType) { - TensorType in_type = input.getType().cast(); + TensorType in_type = cast(input.getType()); if (in_type.getElementType() != outType.getElementType()) { TensorType promotedType = @@ -162,7 +162,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, // dimension, the dimension sizes must either be equal, one of them is 1, or // one of them does not exist. Operation *op = input.getDefiningOp(); - TensorType in_type = input.getType().dyn_cast(); + TensorType in_type = dyn_cast(input.getType()); if (in_type.getElementType() != outType.getElementType()) { TensorType promoted_type = @@ -217,7 +217,7 @@ FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, ArrayRef inpDims, size_t dimSizeIndexBits) { - auto valueTy = value.getType().dyn_cast(); + auto valueTy = dyn_cast(value.getType()); if (!valueTy) { return rewriter.notifyMatchFailure( op, "getDimSizesOfTensor(): the input is not a ranked tensor"); @@ -240,7 +240,7 @@ FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, size_t dimSizeIndexBits) { - auto valueTy = value.getType().dyn_cast(); + auto valueTy = dyn_cast(value.getType()); if (!valueTy) { return rewriter.notifyMatchFailure( op, "getDimSizesOfTensor(): the input is not a ranked tensor"); @@ -279,7 +279,7 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, op, "unsqueeze dimensions must be specified in order"); auto loc = op->getLoc(); - auto rankTy = tensor.getType().dyn_cast(); + auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); Type intType = rewriter.getIntegerType(dimSizeIndexBits); auto one = rewriter.create( diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index fdd482a0d09f..e43105ea1b2b 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -72,7 +72,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, SmallVector endIndices; SmallVector strides; - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); size_t rank = inputTy.getRank(); startIndices.reserve(rank); endIndices.reserve(rank); @@ -116,7 +116,7 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, std::optional stepOpt, int64_t dim, size_t dimSizeIndexBits) { auto loc = op->getLoc(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); auto rank = inputTy.getRank(); dim = (dim + rank) % rank; @@ -168,8 +168,7 @@ class ConvertAtenViewOp : public ConvertAtenOp { LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rankType = - adaptor.getSelf().getType().template dyn_cast(); + auto rankType = dyn_cast(adaptor.getSelf().getType()); if (!rankType) return op.emitError("Only ranked tensor types are currently supported"); @@ -233,11 +232,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only ranked tensor types are supported"); auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( @@ -275,7 +274,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only ranked tensor types are supported"); @@ -318,7 +317,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSqueezeDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only ranked tensor types are supported"); @@ -369,7 +368,7 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenUnsqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) { return op.emitError("only tensor types are currently supported"); } @@ -378,7 +377,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); int64_t inputRank = - adaptor.getSelf().getType().cast().getRank(); + cast(adaptor.getSelf().getType()).getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -397,7 +396,7 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimsCollapseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.getA().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getA().getType()); if (!selfType) { return op.emitError("only tensor types are currently supported"); } diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 0ee49e22e72f..3cf821944bc4 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -89,8 +89,8 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, Value indices, Value src, int64_t dim) { // Get information on types for inputs - RankedTensorType indexType = indices.getType().cast(); - RankedTensorType srcSelf = src.getType().cast(); + RankedTensorType indexType = cast(indices.getType()); + RankedTensorType srcSelf = cast(src.getType()); // Store location for insertions Location loc = src.getLoc(); @@ -219,7 +219,7 @@ static Value createTMTensorScatterOp( llvm::ArrayRef dimensionsMap, bool uniqueIndices, function_ref bodyBuild) { auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap); - auto originalTensorType = original.getType().cast(); + auto originalTensorType = cast(original.getType()); Type originalElementType = originalTensorType.getElementType(); auto scatterOp = b.create( loc, originalTensorType, ValueRange{updates, indices}, @@ -241,8 +241,8 @@ static Value createTMTensorScanOp( OpBuilder &b, Location loc, Value input, Value output, Value accumulator, int64_t dim, bool inclusive, function_ref bodyBuild) { - auto inputType = input.getType().cast(); - auto accType = accumulator.getType().cast(); + auto inputType = cast(input.getType()); + auto accType = cast(accumulator.getType()); Type elementType = inputType.getElementType(); auto scanOp = b.create( loc, TypeRange{inputType, accType}, input, @@ -287,7 +287,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, // Step 3. Create comparison op which will be used as the sorting predicate. Value compareOp; - if (auto intType = elementTypes[0].dyn_cast()) { + if (auto intType = dyn_cast(elementTypes[0])) { // Case for using arith::CmpIOp. arith::CmpIPredicate ge = arith::CmpIPredicate::sge; arith::CmpIPredicate le = arith::CmpIPredicate::sle; @@ -329,9 +329,9 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern { Value index = adaptor.getIndex(); Value src = adaptor.getSrc(); - RankedTensorType selfType = self.getType().cast(); - RankedTensorType indexType = index.getType().cast(); - RankedTensorType srcType = src.getType().cast(); + RankedTensorType selfType = cast(self.getType()); + RankedTensorType indexType = cast(index.getType()); + RankedTensorType srcType = cast(src.getType()); if (selfType.getRank() != indexType.getRank() || indexType.getRank() != srcType.getRank()) return rewriter.notifyMatchFailure(op, @@ -385,7 +385,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { // TODO: Add a check to verify that the input tensor elements are all // non-negative. // Check whether the input is a 1-d tensor of integer type or not. - RankedTensorType inputType = input.getType().cast(); + RankedTensorType inputType = cast(input.getType()); if (inputType.getRank() != 1 || !inputType.getElementType().isa()) return rewriter.notifyMatchFailure( @@ -394,7 +394,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { // Check whether the input tensor element type is i64 or not. IntegerType inputIntegerType = - inputType.getElementType().cast(); + cast(inputType.getElementType()); if (inputIntegerType.getWidth() != 64) return rewriter.notifyMatchFailure( op, @@ -409,7 +409,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { SmallVector maxTensorSizes; ValueTensorType maxTensorType = ValueTensorType::get( context, llvm::ArrayRef(maxTensorSizes), - torchTypeInput.getType().cast().getDtype()); + cast(torchTypeInput.getType()).getDtype()); Value maxTensor = rewriter.create(loc, maxTensorType, torchTypeInput); maxTensor = typeConverter->materializeTargetConversion( @@ -432,7 +432,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { makeShapeTorchCompatible(inputType.getShape())[0], 1}; ValueTensorType expandInputType = ValueTensorType::get( context, llvm::ArrayRef(expandedInputSizes), - torchTypeInput.getType().cast().getDtype()); + cast(torchTypeInput.getType()).getDtype()); Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value expandedInputTensor = rewriter.create( @@ -571,7 +571,7 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, } BaseTensorType unsqueezedTensorType = - indices[0].getType().cast(); + cast(indices[0].getType()); Value indicesTorchList = b.create( loc, Torch::ListType::get(unsqueezedTensorType), indices); llvm::SmallVector concatShape{ @@ -691,7 +691,7 @@ class ConvertAten_IndexPutImplOp auto inputType = cast(input.getType()); auto valuesType = cast(values.getType()); int64_t inputRank = inputType.getSizes().size(); - auto valuesTensorType = op.getValues().getType().cast(); + auto valuesTensorType = cast(op.getValues().getType()); auto resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); @@ -902,9 +902,9 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp Value gradOutput = adaptor.getGradOutput(); Value input = adaptor.getSelf(); RankedTensorType gradOutputType = - gradOutput.getType().cast(); + cast(gradOutput.getType()); Type gradOutputElemType = gradOutputType.getElementType(); - RankedTensorType inputType = input.getType().cast(); + RankedTensorType inputType = cast(input.getType()); Type inputElemType = inputType.getElementType(); int64_t tensorOperandRank = inputType.getRank(); @@ -914,7 +914,7 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); - RankedTensorType indicesType = indices.getType().cast(); + RankedTensorType indicesType = cast(indices.getType()); Type indicesElemType = indicesType.getElementType(); // The element type of the `input` and `grad_output` should be same. @@ -1100,11 +1100,11 @@ class ConvertAtenScatterReduceTwoOp Location loc = op.getLoc(); RankedTensorType selfType = - adaptor.getSelf().getType().cast(); + cast(adaptor.getSelf().getType()); RankedTensorType indexType = - adaptor.getIndex().getType().cast(); + cast(adaptor.getIndex().getType()); RankedTensorType srcType = - adaptor.getSrc().getType().cast(); + cast(adaptor.getSrc().getType()); Value self = adaptor.getSelf(); @@ -1324,7 +1324,7 @@ class ConvertAtenSortOp : public OpConversionPattern { // Step 1. Fetch Input to sort. Value inputTensor = adaptor.getSelf(); - auto inputType = inputTensor.getType().cast(); + auto inputType = cast(inputTensor.getType()); unsigned inputRank = inputType.getRank(); // Step 2. Fetch dimension to perform sort in. @@ -1414,7 +1414,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern { .cast(); Type elementType = resultType.getElementType(); Type inputElementType = - input.getType().cast().getElementType(); + cast(input.getType()).getElementType(); // Converting the input element type to the result's element type. // The only possible mismatch would be when the input element type is an @@ -1486,7 +1486,7 @@ class ConvertAtenScaledDotProductAttentionOp Value isCausal = op.getIsCausal(); Value scale = op.getScale(); Type elementType = - adaptor.getQuery().getType().cast().getElementType(); + cast(adaptor.getQuery().getType()).getElementType(); // Verify inputs (only support defaults) if (!mask.getType().isa()) @@ -1557,10 +1557,9 @@ class ConvertAtenScaledDotProductAttentionOp key = collapseBatch(key); value = collapseBatch(value); - SmallVector outSizes( - query.getType().cast().getShape()); + SmallVector outSizes(cast(query.getType()).getShape()); SmallVector valueSizes( - value.getType().cast().getShape()); + cast(value.getType()).getShape()); outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1]; SmallVector outSizesDynamic( getTensorSizes(rewriter, op.getLoc(), query)); diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index d2663b3658f0..f3ec5c01095f 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -79,9 +79,9 @@ class ConvertAtenShapeToTensorPatternOp ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto operand = adaptor.getOperands()[0]; - auto operandTy = operand.getType().cast(); + auto operandTy = cast(operand.getType()); auto resultTy = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); int64_t rank = operandTy.getRank(); if (rank == 0) { diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 71bc0b51e0fc..010e7fce01dc 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -43,7 +43,7 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure(op, @@ -93,9 +93,9 @@ class ConvertAtenBinaryOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); Value rhs = adaptor.getOther(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return rewriter.notifyMatchFailure(op, @@ -235,15 +235,15 @@ class ConvertAtenAddSubOp : public OpConversionPattern { // alpha : scalar: i32/i64/f32 // output: tensor: tensor Value lhs = adaptor.getSelf(); - auto lhsType = lhs.getType().dyn_cast(); + auto lhsType = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); - auto rhsType = rhs.getType().dyn_cast(); + auto rhsType = dyn_cast(rhs.getType()); if (!lhsType) return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - if (auto lhsElemTy = lhsType.getElementType().dyn_cast()) { + if (auto lhsElemTy = dyn_cast(lhsType.getElementType())) { if (lhsElemTy.getWidth() > 64) return rewriter.notifyMatchFailure( op, "Integers with widths greater than 64 are not supported"); @@ -284,7 +284,7 @@ class ConvertAtenAddSubOp : public OpConversionPattern { op->getLoc(), RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs); // reinitialize right value type to tensor - rhsType = rhs.getType().dyn_cast(); + rhsType = dyn_cast(rhs.getType()); } auto rhsTensor = rhsType ? rhs : rhsAsTensor; @@ -337,9 +337,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().dyn_cast(); + auto lhsTy = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); - auto rhsTy = rhs.getType().dyn_cast(); + auto rhsTy = dyn_cast(rhs.getType()); if (!lhsTy) return rewriter.notifyMatchFailure(op, @@ -409,7 +409,7 @@ class ConvertAtenMulOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - auto lhsType = lhs.getType().dyn_cast(); + auto lhsType = dyn_cast(lhs.getType()); if (!lhsType) return rewriter.notifyMatchFailure(op, @@ -430,7 +430,7 @@ class ConvertAtenMulOp : public OpConversionPattern { } else { Value rhsAsTensor; Value rhs = adaptor.getOther(); - auto rhsType = rhs.getType().dyn_cast(); + auto rhsType = dyn_cast(rhs.getType()); if (!rhsType) { if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), rhsAsTensor, outElemTy, {}))) { @@ -469,9 +469,9 @@ class ConvertAtenDivOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().dyn_cast(); + auto lhsTy = dyn_cast(lhs.getType()); Value rhs = adaptor.getOther(); - auto rhsTy = rhs.getType().dyn_cast(); + auto rhsTy = dyn_cast(rhs.getType()); if (!lhsTy) return rewriter.notifyMatchFailure(op, @@ -497,7 +497,7 @@ class ConvertAtenDivOp : public OpConversionPattern { // auto result; Value result; - if (outType.getElementType().template isa()) { + if (isa(outType.getElementType())) { // The input to the reciprocal is an integer sometimes, and we may need to // promote it to a floating point. Per TOSA specification, the input types // can only be floating point for tosa::ReciprocalOp. @@ -538,7 +538,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenTanhOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (selfTy && selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); @@ -555,7 +555,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSigmoidOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (selfTy && selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); @@ -572,7 +572,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); // Maps to tosa.clamp which has both int and fp limits. int64_t clampMin = 0; @@ -602,7 +602,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy.getElementType().isa()) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization currently supported"); @@ -660,7 +660,7 @@ class ConvertAtenReductionOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure(op, @@ -713,7 +713,7 @@ class ConvertAtenMultipleDimsReductionOp "non-const dim parameter unsupported"); int64_t N = reduceDims.size(); int64_t inputRank = - adaptor.getSelf().getType().template cast().getRank(); + cast(adaptor.getSelf().getType()).getRank(); for (unsigned i = 0; i < N; i++) { reduceDims[i] = toPositiveDim(reduceDims[i], inputRank); if (!isValidDim(reduceDims[i], inputRank)) @@ -751,7 +751,7 @@ class ConvertAtenOneDimReductionOp return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); int64_t inputRank = - adaptor.getSelf().getType().template cast().getRank(); + cast(adaptor.getSelf().getType()).getRank(); reduceDim = toPositiveDim(reduceDim, inputRank); if (!isValidDim(reduceDim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -782,7 +782,7 @@ class ConvertAtenAllDimsReductionOp ElementsAttr &reduceDimsAttr, bool &keepDims) const override { auto self = adaptor.getSelf(); - auto selfTy = self.getType().template cast(); + auto selfTy = cast(self.getType()); // Select all dims to reduce SmallVector reduceDims; @@ -804,7 +804,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().template cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure( @@ -835,7 +835,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Create a single instance of tosa.argmax. // Multiple dims require chained construct. auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value { - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); auto inputShape = makeShapeTorchCompatible(inputTy.getShape()); SmallVector outputShapeArr = {}; int32_t i = 0; @@ -865,7 +865,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Convert the final index to i64 for backend finalization, However, i64 // is not a defined type for tosa.cast, so using arith.extsi instead. auto castToInt64 = [&](Value result) -> LogicalResult { - auto resTy = result.getType().cast(); + auto resTy = cast(result.getType()); if (!resTy) return rewriter.notifyMatchFailure(op, "Argmax: Result is not a shaped type"); @@ -915,7 +915,7 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfTy = self.getType().template cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure( @@ -1010,7 +1010,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().template cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure( @@ -1021,7 +1021,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only floating-point datatype legalization supported"); auto outType = - getTypeConverter()->convertType(op.getType()).template cast(); + cast(getTypeConverter()->convertType(op.getType())); Value expTensor; Value expScalar = op.getExponent(); @@ -1063,8 +1063,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { ConversionPatternRewriter &rewriter, Value &lhs, Value &rhs, Value &output) const { - auto lhsTy = lhs.getType().cast(); - auto rhsTy = rhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -1097,7 +1097,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // construct the input and output reshaping logic. auto getRankBroadcastedShape = [&](Value tensor, bool isRHS) -> SmallVector { - auto tensorTy = tensor.getType().cast(); + auto tensorTy = cast(tensor.getType()); auto tensorShape = makeShapeTorchCompatible(tensorTy.getShape()); auto tensorRank = tensorTy.getRank(); @@ -1151,7 +1151,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // TOSA matmul is performed on two 3D inputs and generates a 3D output. // Lower ranked tensors are dim-1 reshaped up to 3D auto reshapeUpTo3DTensor = [&](Value tensor) -> Value { - auto tensorTy = tensor.getType().cast(); + auto tensorTy = cast(tensor.getType()); auto rank = tensorTy.getRank(); assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3"); @@ -1440,9 +1440,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { } auto matmulLhsShape = makeShapeTorchCompatible( - matmulLhs.getType().template cast().getShape()); + cast(matmulLhs.getType()).getShape()); auto matmulRhsShape = makeShapeTorchCompatible( - matmulRhs.getType().template cast().getShape()); + cast(matmulRhs.getType()).getShape()); // The reshape/transpose should ensure the tosa.matmul always has same // batch size for either matrix. If if shapes are dynamic, they'll be @@ -1642,10 +1642,10 @@ class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { ConversionPatternRewriter &rewriter, Value &lhs, Value &rhs) const override { lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); rhs = adaptor.getOther(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return rewriter.notifyMatchFailure( @@ -1666,10 +1666,10 @@ class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { Value &lhs, Value &rhs) const override { lhs = adaptor.getSelf(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); rhs = adaptor.getMat2(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return rewriter.notifyMatchFailure( @@ -1703,10 +1703,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { Value &lhs, Value &rhs) const override { lhs = adaptor.getInput(); - auto lhsTy = lhs.getType().cast(); + auto lhsTy = cast(lhs.getType()); rhs = adaptor.getWeight(); - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); if (!lhsTy || !rhsTy) return rewriter.notifyMatchFailure( @@ -1744,14 +1744,13 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto biasTy = bias.getType(); // TOSA does not mandate that elementwise op tensors need to be ranked. - if (!biasTy.template isa() && - !biasTy.template isa()) + if (!isa(biasTy) && !isa(biasTy)) return rewriter.notifyMatchFailure( op, "Only tensor types supported in GEMM to TOSA for bias tensor"); // RHS must have its last two dims transposed prior to matrix // multiplication. - auto rhsTy = rhs.getType().cast(); + auto rhsTy = cast(rhs.getType()); auto rhsRank = rhsTy.getRank(); auto rhsShape = makeShapeTorchCompatible(rhsTy.getShape()); auto rhsElemTy = rhsTy.getElementType(); @@ -1789,7 +1788,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { "Failed to perform matmul operation"); Value matmulPlusBias = matmulOutput; - if (!biasTy.template isa()) { + if (!isa(biasTy)) { // Bias addition broadcasts to the matmul output shape. matmulPlusBias = rewriter @@ -1818,7 +1817,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto otherScalar = op.getOther(); auto alphaScalar = op.getAlpha(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Rsub"); @@ -1867,8 +1866,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto input = adaptor.getInput(); auto weight = adaptor.getWeight(); - auto inputTy = input.getType().cast(); - auto weightTy = weight.getType().cast(); + auto inputTy = cast(input.getType()); + auto weightTy = cast(weight.getType()); auto outputTy = getTypeConverter() ->convertType(op.getType()) .template cast(); @@ -1893,7 +1892,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Bias is optional. TOSA mandates a zero tensor here, so construct one if // required. auto bias = adaptor.getBias(); - if (adaptor.getBias().getType().template isa()) { + if (isa(adaptor.getBias().getType())) { // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and // accumulator) are 48-bit and not 32-bit, and requires the use of APInt to // define a 48-bit int. @@ -1909,7 +1908,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); } } else { - if (!bias.getType().cast()) + if (!cast(bias.getType())) return rewriter.notifyMatchFailure( op, "Bias provided but not a ranked tensor"); } @@ -2115,7 +2114,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Reshape"); @@ -2199,7 +2198,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a ranked tensor output - if (!adaptor.getInput().getType().dyn_cast()) + if (!dyn_cast(adaptor.getInput().getType())) return rewriter.notifyMatchFailure( op, "Only ranked tensor types are supported"); @@ -2211,8 +2210,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (op.getMomentum().getType().isa()) return rewriter.notifyMatchFailure(op, "Unsupported None for momentum"); - auto meanType = adaptor.getRunningMean().getType().dyn_cast(); - auto varianceType = adaptor.getRunningVar().getType().dyn_cast(); + auto meanType = dyn_cast(adaptor.getRunningMean().getType()); + auto varianceType = dyn_cast(adaptor.getRunningVar().getType()); if (!varianceType || !meanType) return rewriter.notifyMatchFailure( op, "Only ranked tensor types are supported"); @@ -2225,7 +2224,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( const TypeConverter *converter, Type outType, const Value toBcast, Value &result) { RankedTensorType toBcastType = - toBcast.getType().dyn_cast(); + dyn_cast(toBcast.getType()); if (toBcastType.getRank() > 1) return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1"); @@ -2298,11 +2297,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // eventually being reshaped for broadcasting. // Not a ranked tensor output - if (!adaptor.getInput().getType().dyn_cast()) + if (!dyn_cast(adaptor.getInput().getType())) return rewriter.notifyMatchFailure( op, "Only ranked tensor types are supported"); - auto inputType = adaptor.getInput().getType().cast(); + auto inputType = cast(adaptor.getInput().getType()); if (inputType.getRank() > 4) return rewriter.notifyMatchFailure(op, "Only up to 4D tensors are supported"); @@ -2317,8 +2316,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (adaptor.getBias().getType().isa()) return rewriter.notifyMatchFailure(op, "Unsupported None for bias"); - auto weightType = adaptor.getWeight().getType().cast(); - auto biasType = adaptor.getBias().getType().cast(); + auto weightType = cast(adaptor.getWeight().getType()); + auto biasType = cast(adaptor.getBias().getType()); int64_t inputRank = inputType.getRank(); Type elemTy = inputType.getElementType(); SmallVector inputTypeShape( @@ -2461,7 +2460,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // element type. All tensors with element types other than integer can reuse // existing elements attribute. // TODO: what about unsigned integer? - if (auto elements = op.getValueAttr().dyn_cast()) { + if (auto elements = dyn_cast(op.getValueAttr())) { if (elements.getElementType().isSignedInteger()) { Type builtinTensorElemTy = outputTy.getElementType(); unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); @@ -2483,7 +2482,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a ranked tensor type - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure(op, "Only ranked tensor types supported"); @@ -2548,7 +2547,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a ranked tensor type - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, @@ -2602,7 +2601,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a ranked tensor type - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, @@ -2637,7 +2636,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2665,7 +2664,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2715,7 +2714,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) { return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2763,7 +2762,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2781,7 +2780,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getInput().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getInput().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2807,7 +2806,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2869,7 +2868,7 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, // // Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4 - auto outType = x.getType().cast(); + auto outType = cast(x.getType()); auto loc = op->getLoc(); auto absX = rewriter.create(loc, outType, x); auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); @@ -2949,7 +2948,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2986,7 +2985,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3043,7 +3042,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) { return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3063,7 +3062,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } Value gradOutput = adaptor.getGradOutput(); - auto gradOutputType = adaptor.getSelf().getType().dyn_cast(); + auto gradOutputType = dyn_cast(adaptor.getSelf().getType()); Type gradOutputElemType = gradOutputType.getElementType(); @@ -3119,14 +3118,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value weight = adaptor.getWeight(); Value indices = adaptor.getIndices(); RankedTensorType outType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); - auto indicesType = indices.getType().dyn_cast(); + auto indicesType = dyn_cast(indices.getType()); if (!indicesType || !indicesType.getElementType().isa()) return rewriter.notifyMatchFailure( op, "Indices must be of integer tensor type"); - auto weightType = weight.getType().cast(); + auto weightType = cast(weight.getType()); if (weightType.getRank() != 2) return op.emitError("weight must be of rank 2"); @@ -3216,7 +3215,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenTransposeIntOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); @@ -3258,12 +3257,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenMaxDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); auto indicesType = - getTypeConverter()->convertType(op.getType(1)).dyn_cast(); + dyn_cast(getTypeConverter()->convertType(op.getType(1))); if (!indicesType) return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); @@ -3334,7 +3333,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); @@ -3406,7 +3405,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); @@ -3500,13 +3499,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto input = adaptor.getSelf(); - auto inputType = adaptor.getSelf().getType().dyn_cast(); + auto inputType = dyn_cast(adaptor.getSelf().getType()); if (!inputType) return rewriter.notifyMatchFailure( op, "Only RankedTensorType input are currently supported"); auto index = adaptor.getIndex(); - auto indexType = adaptor.getIndex().getType().dyn_cast(); + auto indexType = dyn_cast(adaptor.getIndex().getType()); auto inputShape = inputType.getShape(); int paramsRank = inputShape.size(); @@ -3593,13 +3592,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto input = adaptor.getSelf(); - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); auto fillValues = adaptor.getValues(); - auto valuesType = adaptor.getValues().getType().dyn_cast(); + auto valuesType = dyn_cast(adaptor.getValues().getType()); if (!valuesType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); @@ -3640,7 +3639,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Multiple None index is not support for now."); } - auto indexNextType = indexNext.getType().dyn_cast(); + auto indexNextType = dyn_cast(indexNext.getType()); auto indexNextShape = indexNextType.getShape(); int64_t size = 1; @@ -3652,7 +3651,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); } - auto indexType = index.getType().dyn_cast(); + auto indexType = dyn_cast(index.getType()); auto indexShape = indexType.getShape(); indexesShape.push_back(makeShapeTorchCompatible(indexShape)); indexesRank.push_back(indexType.getRank()); @@ -3734,7 +3733,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]] auto input = adaptor.getSelf(); auto inputTensorType = - adaptor.getSelf().getType().dyn_cast(); + dyn_cast(adaptor.getSelf().getType()); // Check input is a tensor type. if (!inputTensorType) return rewriter.notifyMatchFailure( @@ -3771,7 +3770,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( for (size_t i = 0; i < indexTensors.size(); i++) { auto index = indexTensors[i]; - auto indexType = index.getType().dyn_cast(); + auto indexType = dyn_cast(index.getType()); auto indexShape = indexType.getShape(); indexesShape.push_back(makeShapeTorchCompatible(indexShape)); indexesRank.push_back(indexType.getRank()); @@ -3837,7 +3836,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Support for multiple index auto index = indexTensors[0]; - auto indexType = index.getType().dyn_cast(); + auto indexType = dyn_cast(index.getType()); auto indexShape = indexType.getShape(); // index i64 to i32 for tosa compatible if (indexType.getElementType() != rewriter.getIntegerType(32)) { @@ -3879,7 +3878,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenAbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); @@ -3896,11 +3895,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); - auto condType = adaptor.getCondition().getType().dyn_cast(); + auto condType = dyn_cast(adaptor.getCondition().getType()); if (!condType) return rewriter.notifyMatchFailure( op, "Only tensor types condition are currently supported"); @@ -3919,11 +3918,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); - auto otherType = adaptor.getOther().getType().dyn_cast(); + auto otherType = dyn_cast(adaptor.getOther().getType()); if (!otherType) return rewriter.notifyMatchFailure( op, "Only tensor types condition are currently supported"); @@ -3955,8 +3954,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: equal_nan is expected to be false"); // check tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); - auto otherType = adaptor.getOther().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto otherType = dyn_cast(adaptor.getOther().getType()); if (!selfType || !otherType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); @@ -3998,7 +3997,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "only tensor types input are currently supported"); @@ -4251,8 +4250,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); - auto srcType = adaptor.getSrc().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto srcType = dyn_cast(adaptor.getSrc().getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); @@ -4297,7 +4296,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); @@ -4355,14 +4354,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().template cast(); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Remainder"); auto outType = - getTypeConverter()->convertType(op.getType()).template cast(); + cast(getTypeConverter()->convertType(op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) @@ -4438,7 +4437,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { // Apply the transposeDims vector on input to generate a transposed form. Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter, Value input, ArrayRef transposeDims) const { - auto inputTy = input.getType().template cast(); + auto inputTy = cast(input.getType()); auto inputElemTy = inputTy.getElementType(); auto inputShape = makeShapeTorchCompatible(inputTy.getShape()); auto inputRank = inputTy.getRank(); @@ -4462,8 +4461,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { Value transposePoolingInputToHwc(AtenOpT op, ConversionPatternRewriter &rewriter, Value input) const { - auto inputRank = - input.getType().template cast().getRank(); + auto inputRank = cast(input.getType()).getRank(); SmallVector nchwToNhwc4DTransposeDims({0, 2, 3, 1}); SmallVector chwToHwc3DTransposeDims({1, 2, 0}); @@ -4476,7 +4474,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { Value transposePoolingOutputToChw(AtenOpT op, ConversionPatternRewriter &rewriter, Value input) const { - auto inputTy = input.getType().template cast(); + auto inputTy = cast(input.getType()); auto inputRank = inputTy.getRank(); SmallVector nhwcToNchw4DTransposeDims({0, 3, 1, 2}); @@ -4547,7 +4545,7 @@ class ConvertAtenAdaptivePoolingOp DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, Type &outputTy) const override { auto inputXchw = adaptor.getSelf(); - auto inputTy = inputXchw.getType().template cast(); + auto inputTy = cast(inputXchw.getType()); if (!inputTy) return rewriter.notifyMatchFailure( op, "Adaptive avgpool requires ranked tensor input"); @@ -4659,7 +4657,7 @@ static LogicalResult getOutputTypeAndPoolingParameters( DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad) { - RankedTensorType inputTy = inputXchw.getType().cast(); + RankedTensorType inputTy = cast(inputXchw.getType()); if (!inputTy) return rewriter.notifyMatchFailure( op, "Pooling op requires ranked tensor input"); @@ -4797,7 +4795,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { // FIXME: Handle layout, device and pin_memory. Assume dtype has been // processed to set output type correctly? // The layout arg should be either `none` or `0` i.e. strided. - if (!op.getLayout().getType().template isa()) { + if (!isa(op.getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( @@ -4808,7 +4806,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { } bool pinMemory; - if (!op.getPinMemory().getType().template isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return rewriter.notifyMatchFailure( @@ -4892,19 +4890,19 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { } // Not a tensor type. - auto selfType = adaptor.getSelf().getType().template dyn_cast(); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType || !outType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shapes input are currently supported"); - auto maskType = adaptor.getMask().getType().template dyn_cast(); + auto maskType = dyn_cast(adaptor.getMask().getType()); if (!maskType) return rewriter.notifyMatchFailure( op, "Only tensor types mask are currently supported"); Value rhs = adaptor.getValue(); - auto rhsType = rhs.getType().template dyn_cast(); + auto rhsType = dyn_cast(rhs.getType()); Value rhsAsTensor; if (!rhsType) { // scalar if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), @@ -4913,11 +4911,11 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { op, "Currently only scalar constants are supported for " "conversion in TOSA operation"); } else { // tensor - rhsType = rhs.getType().dyn_cast(); + rhsType = dyn_cast(rhs.getType()); } auto rhsTensor = rhsType ? rhs : rhsAsTensor; - auto rhsTensorType = rhsTensor.getType().template dyn_cast(); + auto rhsTensorType = dyn_cast(rhsTensor.getType()); if (rhsTensorType.getElementType() != outElemTy) rhsTensor = rewriter.create( op.getLoc(), @@ -4940,7 +4938,7 @@ class ConvertAtenCloneOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { int64_t memoryFormat; - if (!op.getMemoryFormat().getType().template isa() && + if (!isa(op.getMemoryFormat().getType()) && (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) || (memoryFormat != torch_upstream::MemoryFormat::Contiguous && @@ -4964,7 +4962,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); auto selfElemTy = selfTy.getElementType(); int64_t rank = selfTy.getRank(); @@ -5033,7 +5031,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { const TypeConverter *typeConverter = this->getTypeConverter(); auto outType = - typeConverter->convertType(op.getType()).cast(); + cast(typeConverter->convertType(op.getType())); int64_t rank = outType.getRank(); int64_t dim; @@ -5074,7 +5072,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Converts AtenSqrtOp into pow(x, 0.5) auto self = adaptor.getSelf(); - auto selfTy = self.getType().dyn_cast(); + auto selfTy = dyn_cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index e06629dd3ea4..f9d3071fd10c 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -117,8 +117,8 @@ template <> tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs) { - auto lhsElemTy = lhs.getType().cast().getElementType(); - auto rhsElemTy = rhs.getType().cast().getElementType(); + auto lhsElemTy = cast(lhs.getType()).getElementType(); + auto rhsElemTy = cast(rhs.getType()).getElementType(); if (isa(lhsElemTy) || isa(rhsElemTy)) { (void)rewriter.notifyMatchFailure(op, "tosa.div only supports integer type"); @@ -148,8 +148,8 @@ std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, // [2,1] [[0, 3, 2],[0, 3, 1]] // ]] 1*4*2 ]] 1*4*2*3 - auto paramsType = paramsValue.getType().dyn_cast(); - auto indexType = indexValue.getType().dyn_cast(); + auto paramsType = dyn_cast(paramsValue.getType()); + auto indexType = dyn_cast(indexValue.getType()); auto paramsShape = paramsType.getShape(); // [1 4 3] auto indexShape = indexType.getShape(); // [1 4 2] int paramsRank = paramsShape.size(); // 3 @@ -214,8 +214,8 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, Type outType, Value paramsValue, Value indicesValue) { auto resultType = dyn_cast(outType); - auto paramsType = paramsValue.getType().dyn_cast(); - auto indicesType = indicesValue.getType().dyn_cast(); + auto paramsType = dyn_cast(paramsValue.getType()); + auto indicesType = dyn_cast(indicesValue.getType()); if (!resultType || !paramsType || !indicesType) return std::nullopt; @@ -420,9 +420,9 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, Value paramsValue, Value indicesValue, Value fillValues) { auto resultType = dyn_cast(outType); - auto paramsType = paramsValue.getType().dyn_cast(); - auto indicesType = indicesValue.getType().dyn_cast(); - auto fillValuesType = fillValues.getType().dyn_cast(); + auto paramsType = dyn_cast(paramsValue.getType()); + auto indicesType = dyn_cast(indicesValue.getType()); + auto fillValuesType = dyn_cast(fillValues.getType()); if (!resultType || !paramsType || !indicesType) return std::nullopt; @@ -572,7 +572,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, tosaFillValuesTileOp.getResult(), rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape)); fillValues = newTosaFillValuesReshapeOp.getResult(); - fillValuesType = fillValues.getType().dyn_cast(); + fillValuesType = dyn_cast(fillValues.getType()); } // fillK: range of each index, total number of fillInput(could be scatter) @@ -691,7 +691,7 @@ std::optional convertReduceOpCommon( Type reduce_element_type, bool is_quantized, double input_scale, int64_t input_zp, double output_scale, int64_t output_zp) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -754,7 +754,7 @@ convertReduceAllOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -769,7 +769,7 @@ convertReduceAnyOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -784,7 +784,7 @@ convertReduceMinOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -799,7 +799,7 @@ convertReduceMaxOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -814,7 +814,7 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -840,7 +840,7 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -863,9 +863,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op, if (input_is_qtype) { auto input_qtype = - input_type.getElementType().cast(); + cast(input_type.getElementType()); auto output_qtype = - output_type.getElementType().cast(); + cast(output_type.getElementType()); int32_t input_shift = 20; @@ -895,7 +895,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis) RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -940,9 +940,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, if (input_is_qtype) { auto input_qtype = - input_type.getElementType().cast(); + cast(input_type.getElementType()); auto output_qtype = - output_type.getElementType().cast(); + cast(output_type.getElementType()); // Combine 'div_scale' as part of output rescale output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale(); @@ -976,7 +976,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims) { RankedTensorType input_type = - input_value.getType().dyn_cast(); + dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 4fabe9f53caf..5c46b8942fdd 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -45,7 +45,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op, Value input_val, double input_scale, int64_t input_zp) { // Output is always int32 type - auto input_type = input_val.getType().dyn_cast(); + auto input_type = dyn_cast(input_val.getType()); assert(input_type); auto output_type = input_type.clone(rewriter.getI32Type()); @@ -58,9 +58,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, Value conv_val, ShapedType input_type, ShapedType weight_type, ShapedType output_type) { auto input_qtype = - input_type.getElementType().dyn_cast(); - auto output_qtype = output_type.getElementType() - .dyn_cast(); + dyn_cast(input_type.getElementType()); + auto output_qtype = + dyn_cast(output_type.getElementType()); double input_scale = input_qtype.getScale(); @@ -71,8 +71,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, int32_t scale_width = scale32 ? 32 : 16; if (auto weight_per_tensor_qtype = - weight_type.getElementType() - .dyn_cast()) { + dyn_cast( + weight_type.getElementType())) { // Per-tensor quantization double weight_scale = weight_per_tensor_qtype.getScale(); @@ -94,8 +94,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, return rescale_op.getResult(); } else if (auto weight_per_channel_qtype = - weight_type.getElementType() - .dyn_cast()) { + dyn_cast( + weight_type.getElementType())) { // Per-channel quantization SmallVector multiplier_arr; SmallVector shift_arr; @@ -311,7 +311,7 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) { LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Value src, Type destType, Value &result) { - Type srcElemTy = src.getType().dyn_cast().getElementType(); + Type srcElemTy = dyn_cast(src.getType()).getElementType(); Type destElemTy = dyn_cast(destType).getElementType(); if (failed(checkValidityOfCast(srcElemTy, destElemTy))) @@ -319,7 +319,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, op, "casting to result dtype is invalid or unsupported"); if (destElemTy.isInteger(1)) { - auto srcType = src.getType().dyn_cast(); + auto srcType = dyn_cast(src.getType()); SmallVector srcShape(srcType.getShape()); uint64_t num_total_elements = 1; for (int64_t a : srcShape) @@ -355,7 +355,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { Operation *op = input.getDefiningOp(); - TensorType inType = input.getType().cast(); + TensorType inType = cast(input.getType()); if (inType.getElementType() != outType.getElementType()) { TensorType promotedType = diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 60e888367d5a..bae25cc7ac60 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -52,7 +52,7 @@ LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) { // Generate IR: dim = dim >= 0 ? dim : dim + inputRank Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim, Value inputRank) { - assert(dim.getType().isa() && + assert(isa(dim.getType()) && "dim arg of toPositiveDim must be integer type"); Value dimAddInputRank = b.create(loc, dim, inputRank); Value cst0 = @@ -132,7 +132,7 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy) { Value initTensor = b.create(loc, getAsOpFoldResult(sizes), elemTy); - RankedTensorType type = initTensor.getType().cast(); + RankedTensorType type = cast(initTensor.getType()); Value c0 = b.create(loc, b.getZeroAttr(type.getElementType())); return b.create(loc, c0, initTensor).getResult(0); @@ -172,7 +172,7 @@ Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) { SmallVector getTensorSizesUntilDim(OpBuilder &b, Location loc, Value tensor, int dim) { - RankedTensorType type = tensor.getType().cast(); + RankedTensorType type = cast(tensor.getType()); assert(dim < type.getRank() && "The given dim must be smaller than tensor rank"); (void)type; @@ -183,7 +183,7 @@ SmallVector getTensorSizesUntilDim(OpBuilder &b, Location loc, } SmallVector getTensorSizes(OpBuilder &b, Location loc, Value tensor) { - RankedTensorType type = tensor.getType().cast(); + RankedTensorType type = cast(tensor.getType()); return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1); } diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 7b8a17682a9e..be07ca276dd6 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -77,7 +77,7 @@ Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v, OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) { - auto t = v.getType().cast(); + auto t = cast(v.getType()); if (t.isDynamicDim(dim)) { return getDimValue(builder, loc, v, dim); } @@ -123,7 +123,7 @@ bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) { static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes, Value rhs, ValueRange rhsSizes, Value output, ValueRange outputSizes, bool transposed = false) { - auto elementType = lhs.getType().cast().getElementType(); + auto elementType = cast(lhs.getType()).getElementType(); Value one = b.create(loc, 1); Value zero = b.create(loc, 0); auto rank = outputSizes.size(); @@ -168,9 +168,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value key = getKey(); Value value = getValue(); Value output = getOutput(); - auto queryType = query.getType().cast(); - auto keyType = key.getType().cast(); - auto valueType = value.getType().cast(); + auto queryType = cast(query.getType()); + auto keyType = cast(key.getType()); + auto valueType = cast(value.getType()); auto queryRank = queryType.getRank(); auto keyRank = keyType.getRank(); auto valueRank = valueType.getRank(); @@ -330,12 +330,12 @@ LogicalResult ScanOp::verify() { if (getNumOutputs() != 2) { return emitOpError("expected two output operands"); } - if (!input().getType().isa()) { + if (!isa(input().getType())) { return emitOpError("expected first input element type to be shaped"); } - auto accumulatorType = accumulator().getType().cast(); - auto inputType = input().getType().cast(); - auto outputType = output().getType().cast(); + auto accumulatorType = cast(accumulator().getType()); + auto inputType = cast(input().getType()); + auto outputType = cast(output().getType()); ArrayRef inputShapes = inputType.getShape(); ArrayRef outputShapes = outputType.getShape(); if (accumulatorType.getElementType() != inputType.getElementType()) { @@ -706,7 +706,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, loadIndices.push_back(Value()); // Populate with empty values. - auto originalTy = original().getType().cast(); + auto originalTy = cast(original().getType()); starts.resize(originalTy.getRank(), Value()); auto updateIvs = ivs.drop_front(1); @@ -797,7 +797,7 @@ LogicalResult SortOp::verify() { if (yieldOp.getNumOperands() != 1) { return op->emitOpError("should yield exactly one operand"); } - auto ty = yieldOp.getOperand(0).getType().dyn_cast(); + auto ty = dyn_cast(yieldOp.getOperand(0).getType()); if (!ty || ty.getWidth() != 1) { return op->emitOpError("should yield i1 type"); } diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 8f34358b9c0f..6e5a6769a843 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -29,7 +29,7 @@ using namespace ::mlir; using namespace ::mlir::torch::TMTensor; static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { - auto memrefType = memref.getType().cast(); + auto memrefType = cast(memref.getType()); auto alloc = b.create( loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType()); b.create(loc, memref, alloc); diff --git a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp index d8af2ef5c493..e31e606b253a 100644 --- a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp +++ b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp @@ -80,7 +80,7 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern { return failure(); } if (llvm::any_of(scalarLoopOp->getResults(), - [&](Value v) { return v.getType().isa(); })) { + [&](Value v) { return isa(v.getType()); })) { return rewriter.notifyMatchFailure( scalarLoopOp, "lower to loops needs to have tensor semantics"); } diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index d57b3e74198e..fdd9875229e8 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -122,14 +122,14 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op, auto func = dyn_cast(op); if (!func) return op->emitError() << "'torch.type_bound' must be attached to a func"; - TypeAttr attr = namedAttr.getValue().dyn_cast(); + TypeAttr attr = dyn_cast(namedAttr.getValue()); if (!attr) return op->emitError() << "'torch.type_bound' must be TypeAttr"; - auto type = attr.getValue().dyn_cast(); + auto type = dyn_cast(attr.getValue()); if (!type) return op->emitError() << "'torch.type_bound' must be of " "!torch.tensor/!torch.vtensor type"; - if (!func.getFunctionType().getInput(argIndex).isa()) + if (!isa(func.getFunctionType().getInput(argIndex))) return op->emitError() << "'torch.type_bound' must be attached to an " "argument of !torch.tensor/!torch.vtensor type"; return success(); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e0a766f78467..33079e35fda1 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -75,7 +75,7 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType, Value tensor) { - auto originalType = tensor.getType().cast(); + auto originalType = cast(tensor.getType()); // Adjust the static information in the type to match between the original and // new types. if (!originalType.hasSameSizesAndDtype(newType)) { @@ -87,7 +87,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, // up creating one op that converts between the value and non-value tensor // domains. If both the original and new types are both non-value tensors, // then we do the copy by going to a value tensor and back. - if (tensor.getType().isa()) + if (isa(tensor.getType())) tensor = builder.create(loc, tensor); if (isa(newType)) tensor = builder.create(loc, tensor); @@ -96,7 +96,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, } bool mlir::torch::Torch::isListPotentiallyMutated(Value list) { - assert(list.getType().isa()); + assert(isa(list.getType())); return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands); } @@ -148,8 +148,7 @@ static Value getScalarIntValue(Value input, Location loc, return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); return rewriter.create( loc, rewriter.getI64IntegerAttr(val)); @@ -777,7 +776,7 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { if (getOperand(0).getType() != getResult().getType()) return nullptr; - if (auto tensorType = getOperand(0).getType().dyn_cast()) { + if (auto tensorType = dyn_cast(getOperand(0).getType())) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(0); } @@ -798,11 +797,11 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg)) || copyArg) return nullptr; // The memory_format arg must be `none`. - if (!getMemoryFormat().getType().isa()) + if (!isa(getMemoryFormat().getType())) return nullptr; - auto inputType = getSelf().getType().cast(); - auto resType = getType().cast(); + auto inputType = cast(getSelf().getType()); + auto resType = cast(getType()); // If the types aren't equal, then we can't fold. if (inputType != resType) return nullptr; @@ -821,7 +820,7 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { // The pin_memory arg should be either constant `False` or `none`. - if (!getPinMemory().getType().isa()) { + if (!isa(getPinMemory().getType())) { bool pinMemory; if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory))) return nullptr; @@ -844,15 +843,15 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { return nullptr; // The device arg must be `none`. - if (!getDevice().getType().isa()) + if (!isa(getDevice().getType())) return nullptr; // The memory_format arg must be `none`. - if (!getMemoryFormat().getType().isa()) + if (!isa(getMemoryFormat().getType())) return nullptr; - auto inputType = getSelf().getType().cast(); - auto resType = getType().cast(); + auto inputType = cast(getSelf().getType()); + auto resType = cast(getType()); // If the types aren't equal, then we can't fold. if (inputType != resType) return nullptr; @@ -863,7 +862,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { return nullptr; // The layout arg should be either `none` or `0` i.e. strided. - if (!getLayout().getType().isa()) { + if (!isa(getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout))) return nullptr; @@ -882,7 +881,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns( // is false patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) { // The pin_memory arg should be either constant `False` or `none`. - if (!op.getPinMemory().getType().isa()) { + if (!isa(op.getPinMemory().getType())) { bool pinMemory; if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) return failure(); @@ -891,7 +890,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns( } // The layout arg should be either `none` or `0` i.e. strided. - if (!op.getLayout().getType().isa()) { + if (!isa(op.getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return failure(); @@ -899,7 +898,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns( return failure(); } - if (op.getDevice().getType().isa()) { + if (isa(op.getDevice().getType())) { // The device arg is `none`. Rewrite to to.dtype. AtenToDtypeOp toDtype = rewriter.create( op.getLoc(), op.getType(), op.getSelf(), op.getDtype(), @@ -985,10 +984,10 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { - auto inputType = getOperand(0).getType().dyn_cast(); + auto inputType = dyn_cast(getOperand(0).getType()); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; - auto resType = getType().dyn_cast(); + auto resType = dyn_cast(getType()); if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1) return nullptr; if (inputType != resType) @@ -1011,7 +1010,7 @@ OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { - if (auto tensorType = getOperand().getType().dyn_cast()) { + if (auto tensorType = dyn_cast(getOperand().getType())) { if (tensorType.hasSizes()) return IntegerAttr::get(IntegerType::get(getContext(), 64), tensorType.getSizes().size()); @@ -1117,7 +1116,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, } if (isa(op)) { - if (op->getOperand(2).getType().isa()) { + if (isa(op->getOperand(2).getType())) { // None rounding mode Value quotient = rewriter.create(loc, lhs, rhs); rewriter.replaceOpWithNewOp(op, outType, @@ -1879,9 +1878,9 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) { - auto resultType = getType().dyn_cast(); + auto resultType = dyn_cast(getType()); if (resultType && resultType.hasDtype() && - resultType.getDtype().isa()) { + isa(resultType.getDtype())) { return getSelf(); } return {}; @@ -1892,9 +1891,9 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) { - auto resultType = getType().dyn_cast(); + auto resultType = dyn_cast(getType()); if (resultType && resultType.hasDtype() && - resultType.getDtype().isa()) { + isa(resultType.getDtype())) { return getSelf(); } return {}; @@ -1905,9 +1904,9 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { - auto resultType = getType().dyn_cast(); + auto resultType = dyn_cast(getType()); if (resultType && resultType.hasDtype() && - resultType.getDtype().isa()) { + isa(resultType.getDtype())) { return getSelf(); } return {}; @@ -1918,7 +1917,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) { - auto resultType = getType().dyn_cast(); + auto resultType = dyn_cast(getType()); if (resultType && resultType.hasDtype() && resultType.getDtype().isa()) { return getSelf(); @@ -1987,7 +1986,7 @@ void AtenDivScalarModeOp::getCanonicalizationPatterns( void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) { - auto inputType = op.getSelf().getType().dyn_cast(); + auto inputType = dyn_cast(op.getSelf().getType()); if (!inputType || !inputType.areAllSizesKnown()) { return failure(); } @@ -2113,7 +2112,7 @@ traceKnownSizeTensorType(Value value, std::optional dim) { if (!value || !value.getType().isa()) return failure(); - auto tensorType = value.getType().cast(); + auto tensorType = cast(value.getType()); if (foundType(tensorType, dim)) return tensorType; @@ -2649,7 +2648,7 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes( .dyn_cast_or_null(); if (!attr) return failure(); - RankedTensorType tensorType = attr.getType().cast(); + RankedTensorType tensorType = cast(attr.getType()); NonValueTensorType returnType = NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(), tensorType.getElementType()); @@ -2691,7 +2690,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes( .dyn_cast_or_null(); if (!attr) return failure(); - RankedTensorType tensorType = attr.getType().cast(); + RankedTensorType tensorType = cast(attr.getType()); ValueTensorType returnType = ValueTensorType::get(tensorType.getContext(), tensorType.getShape(), tensorType.getElementType()); @@ -2751,8 +2750,8 @@ void TensorStaticInfoCastOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// LogicalResult CopyToNonValueTensorOp::verify() { - auto resultType = getResult().getType().cast(); - auto operandType = getOperand().getType().cast(); + auto resultType = cast(getResult().getType()); + auto operandType = cast(getOperand().getType()); if (!resultType.hasSameSizesAndDtype(operandType)) return emitError() << "operand and result must have same sizes and dtype"; return success(); @@ -2762,7 +2761,7 @@ LogicalResult CopyToNonValueTensorOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto resultType = operands[0].getType().cast(); + auto resultType = cast(operands[0].getType()); inferredReturnTypes.push_back(resultType.getWithoutValueSemantics()); return success(); } @@ -2778,8 +2777,8 @@ void CopyToNonValueTensorOp::getEffects( //===----------------------------------------------------------------------===// LogicalResult CopyToValueTensorOp::verify() { - auto resultType = getResult().getType().cast(); - auto operandType = getOperand().getType().cast(); + auto resultType = cast(getResult().getType()); + auto operandType = cast(getOperand().getType()); if (!resultType.hasSameSizesAndDtype(operandType)) return emitError() << "operand and result must have same sizes and dtype"; return success(); @@ -2789,7 +2788,7 @@ LogicalResult CopyToValueTensorOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto resultType = operands[0].getType().cast(); + auto resultType = cast(operands[0].getType()); inferredReturnTypes.push_back(resultType.getWithValueSemantics()); return success(); } @@ -3004,7 +3003,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { - auto operandType = getSelf().getType().dyn_cast(); + auto operandType = dyn_cast(getSelf().getType()); if (!operandType) return nullptr; if (operandType.hasDtype()) { @@ -3493,8 +3492,8 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { - auto inType = getOperand(0).getType().dyn_cast(); - auto outType = getResult().getType().dyn_cast(); + auto inType = dyn_cast(getOperand(0).getType()); + auto outType = dyn_cast(getResult().getType()); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || !outType.hasDtype()) return nullptr; @@ -3534,8 +3533,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { IntegerAttr end = dyn_cast_or_null(adaptor.getEnd()); IntegerAttr step = dyn_cast_or_null(adaptor.getStep()); IntegerAttr dim = dyn_cast_or_null(adaptor.getDim()); - auto inType = getOperand(0).getType().dyn_cast(); - auto outType = getResult().getType().dyn_cast(); + auto inType = dyn_cast(getOperand(0).getType()); + auto outType = dyn_cast(getResult().getType()); if (start && end && step && step.getValue().getSExtValue() == 1 && start.getValue().getSExtValue() == 0 && @@ -3793,7 +3792,7 @@ OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { - BaseTensorType tensorType = getA().getType().cast(); + BaseTensorType tensorType = cast(getA().getType()); if (tensorType.hasDtype()) { torch_upstream::ScalarType scalarType = Torch::getScalarTypeForType(tensorType.getDtype()); @@ -4568,7 +4567,7 @@ LogicalResult AtenNormScalarOp::verify() { // Per PyTorch docs, only float and complex types are valid for norm // operation. - auto inTensor = getSelf().getType().cast(); + auto inTensor = cast(getSelf().getType()); // If no dtype is specified, it will default to a float one. if (!inTensor.hasDtype()) { @@ -4605,8 +4604,8 @@ LogicalResult AtenPermuteOp::verify() { return success(); } - auto outType = getResult().getType().cast(); - auto inType = getSelf().getType().cast(); + auto outType = cast(getResult().getType()); + auto inType = cast(getSelf().getType()); if (!outType.hasSizes() || !inType.hasSizes()) { return success(); @@ -4689,8 +4688,8 @@ LogicalResult AtenPermuteOp::verify() { LogicalResult AtenLinalgCrossOp::verify() { - auto selfType = getSelf().getType().cast(); - auto otherType = getOther().getType().cast(); + auto selfType = cast(getSelf().getType()); + auto otherType = cast(getOther().getType()); if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() || !otherType.hasSizes()) { @@ -4857,7 +4856,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() { // Check that initial values satisfy type bounds. for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) { - auto symName = initialize.getSlotSymNames()[i].cast(); + auto symName = cast(initialize.getSlotSymNames()[i]); auto initialValue = initialize.getOperand(i); auto globalSlotOp = symbolTable.lookup(symName.getValue()); if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) { diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 000efbc7ceb1..6eb949e589c6 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -49,7 +49,7 @@ class AdjustCallingConventionForFunc // The incoporation of the torch.type_bound arg attr is context-dependent. for (auto type : llvm::enumerate(func.getArgumentTypes())) { - if (type.value().isa()) { + if (isa(type.value())) { auto typeBoundAttr = func.getArgAttrOfType(type.index(), typeBoundIdent); Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type(); @@ -61,7 +61,7 @@ class AdjustCallingConventionForFunc ? typeBoundAttr.getValue() : type.value()); continue; - } else if (auto none = type.value().dyn_cast()) { + } else if (auto none = dyn_cast(type.value())) { continue; } // TODO: add tuple type. @@ -111,7 +111,7 @@ class AdjustCallingConventionForCall SmallVector newOperands; for (auto operand : llvm::enumerate(adaptor.getOperands())) { - if (operand.value().getType().isa()) + if (isa(operand.value().getType())) continue; auto it = typeBoundMap.find({call.getCallee(), operand.index()}); if (it != typeBoundMap.end()) { @@ -167,9 +167,9 @@ class AdjustCallingConventionForReturn for (auto operand : adaptor.getOperands()) { if (!operand) continue; - if (operand.getType().isa()) + if (isa(operand.getType())) continue; - if (auto tuple = operand.getType().dyn_cast()) { + if (auto tuple = dyn_cast(operand.getType())) { Location loc = op.getLoc(); for (auto en : llvm::enumerate(tuple.getContainedTypes())) { auto i = rewriter.create( @@ -207,7 +207,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, [](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return copyTensorToType(builder, loc, type, inputs[0]); }); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 49dd5319514b..62d5f7335db8 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -29,7 +29,7 @@ using namespace mlir::torch::Torch; // Helper function to check whether the `dtype` is None or Float type. static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { - if (dtype.getType().isa()) + if (isa(dtype.getType())) return true; int64_t dtypeInt; if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) @@ -87,7 +87,7 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc, Value keepDimCst = rewriter.create(loc, keepDim); Value dtype = rewriter.create(loc); Type resultType = computeReductionType( - rewriter, op, input.getType().cast(), dim, keepDim); + rewriter, op, cast(input.getType()), dim, keepDim); if (!resultType) return nullptr; return rewriter.create(loc, resultType, input, dimList, @@ -100,7 +100,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, bool keepDim) { Value keepDimCst = rewriter.create(loc, keepDim); BaseTensorType valueType = - computeReductionType(rewriter, op, input.getType().cast(), + computeReductionType(rewriter, op, cast(input.getType()), dim, keepDim) .cast(); if (!valueType) @@ -296,7 +296,7 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, int64_t contractingDimsLength, int64_t otherDimsLength, int64_t reduceDimsLength, bool isLhs) { - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + reduceDimsLength; SmallVector inputShapeTensor; @@ -415,7 +415,7 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, SmallVector &contractingDims, SmallVector &otherDims, SmallVector &reduceDims, bool isLhs) { - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); llvm::SmallDenseMap dimTokenMap; for (size_t idx = 0; idx < dimTokens.size(); ++idx) { dimTokenMap[dimTokens[idx]] = idx; @@ -451,8 +451,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, Value &result, SmallVector &resultTokens, SmallVector &finalResultTokens) { - auto lhsType = lhs.getType().cast(); - auto rhsType = rhs.getType().cast(); + auto lhsType = cast(lhs.getType()); + auto rhsType = cast(rhs.getType()); Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() : rhsType.getOptionalDtype(); @@ -562,7 +562,7 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter, Value input, SmallVector &inputTokens, SmallVector &outTokens) { - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); llvm::SmallDenseSet outTokenSet(outTokens.begin(), outTokens.end()); SmallVector sumDims; @@ -643,7 +643,7 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { op, "Expected a constant boolean value for keepDim"); Value input = op.getSelf(); - auto inputTy = input.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); if (!inputTy || !inputTy.hasSizes()) { return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); @@ -677,7 +677,7 @@ class DecomposeAtenTriuOp : public OpRewritePattern { MLIRContext *context = op.getContext(); Location loc = op.getLoc(); Value input = op.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes() || !inputType.hasDtype()) { return rewriter.notifyMatchFailure(op, "should have shape and dtype"); } @@ -764,7 +764,7 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { Value dim = op.getDim(); Value self = op.getSelf(); - auto resultTy = op.getType().cast(); + auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have sizes and dtype"); @@ -785,8 +785,8 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { rewriter.create(loc, one.getType(), start, one); Value slice = rewriter.create( loc, - computeReductionType(rewriter, op, - self.getType().cast(), dim, + computeReductionType(rewriter, op, cast(self.getType()), + dim, /*keepDim=*/true), op.getSelf(), dim, start, startPlusOne, /*step=*/one); @@ -988,7 +988,7 @@ class DecomposeAtenGluOp : public OpRewritePattern { Value self = op.getSelf(); Value dim = op.getDim(); - auto outputTy = op.getType().dyn_cast(); + auto outputTy = dyn_cast(op.getType()); if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "Expected output type having sizes and dtype"); @@ -1069,7 +1069,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "unimplemented: m must be constant"); Value none = rewriter.create(loc); - auto outType = op.getType().dyn_cast(); + auto outType = dyn_cast(op.getType()); if (!outType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); @@ -1111,13 +1111,13 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { // compare unsqueezed input with boundaries auto eqType = ValueTensorType::get( - context, op.getType().cast().getSizes(), + context, cast(op.getType()).getSizes(), IntegerType::get(context, 1)); Value eqTensor = rewriter.create(loc, eqType, unsqzRangeN, rangeM); Value dtype = op.getDtype(); - if (dtype.getType().isa()) { + if (isa(dtype.getType())) { rewriter.replaceOp(op, eqTensor); return success(); } else { @@ -1210,7 +1210,7 @@ class DecomposeAtenReshapeOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value input = op.getSelf(); // TODO: Handle non value tensor type operands. - if (!input.getType().isa()) { + if (!isa(input.getType())) { return rewriter.notifyMatchFailure( op, "unimplemented: only value tensor type operands are supported"); } @@ -1248,7 +1248,7 @@ class DecomposeAtenEinsumOp : public OpRewritePattern { } auto allTensorHasSizes = [](Value tensor) { - auto type = tensor.getType().dyn_cast(); + auto type = dyn_cast(tensor.getType()); if (!type || !type.hasSizes()) return false; return true; @@ -1267,7 +1267,7 @@ class DecomposeAtenEinsumOp : public OpRewritePattern { if (equation.find("...") != std::string::npos) { SmallVector inputRanks; for (Value tensor : inputTensors) { - auto type = tensor.getType().cast(); + auto type = cast(tensor.getType()); inputRanks.push_back(type.getSizes().size()); } @@ -1332,10 +1332,10 @@ class DecomposeAtenTraceOp : public OpRewritePattern { rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - BaseTensorType inputType = self.getType().cast(); + BaseTensorType inputType = cast(self.getType()); Value output = op.getResult(); - BaseTensorType outputType = output.getType().cast(); + BaseTensorType outputType = cast(output.getType()); ArrayRef inputShape = inputType.getSizes(); int64_t diagonalSize = std::min(inputShape[0], inputShape[1]); @@ -1399,7 +1399,7 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - BaseTensorType resultTensorType = op.getType().cast(); + BaseTensorType resultTensorType = cast(op.getType()); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); @@ -1410,7 +1410,7 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { "Only support floating-point type"); // If `dtype` arg is non-none then convert the input to `dtype`. - if (!op.getDtype().getType().isa()) { + if (!isa(op.getDtype().getType())) { Location loc = op.getLoc(); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); @@ -1440,15 +1440,15 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(Aten_SoftmaxOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - BaseTensorType tensorType = self.getType().cast(); - if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) + BaseTensorType tensorType = cast(self.getType()); + if (!tensorType.hasDtype() || !isa(tensorType.getDtype())) return rewriter.notifyMatchFailure(op, "Only support floating type"); bool halfToFloat; if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) return rewriter.notifyMatchFailure( op, "Expected a boolean value for half_to_float"); - BaseTensorType resultTensorType = op.getType().cast(); + BaseTensorType resultTensorType = cast(op.getType()); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); @@ -1500,8 +1500,8 @@ class DecomposeAten_SoftmaxBackwardDataOp Value output = op.getOutput(); Value dim = op.getDim(); - BaseTensorType tensorType = gradOutput.getType().cast(); - if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) + BaseTensorType tensorType = cast(gradOutput.getType()); + if (!tensorType.hasDtype() || !isa(tensorType.getDtype())) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value newGrad = @@ -1536,8 +1536,8 @@ class DecomposeAtenTanhBackwardOp // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2). Value output = op.getOutput(); - BaseTensorType tensorType = gradOutput.getType().cast(); - if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) + BaseTensorType tensorType = cast(gradOutput.getType()); + if (!tensorType.hasDtype() || !isa(tensorType.getDtype())) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value tanhSquare = @@ -1567,8 +1567,8 @@ class DecomposeAten_LogSoftmaxBackwardDataOp Value output = op.getOutput(); Value dim = op.getDim(); - BaseTensorType tensorType = gradOutput.getType().cast(); - if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) + BaseTensorType tensorType = cast(gradOutput.getType()); + if (!tensorType.hasDtype() || !isa(tensorType.getDtype())) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value expOut = rewriter.create(loc, tensorType, output); @@ -1650,8 +1650,8 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { Value keepDim = op.getKeepdim(); Value result = op.getResult(); - BaseTensorType inputType = input.getType().cast(); - BaseTensorType indicesTensorType = result.getType().cast(); + BaseTensorType inputType = cast(input.getType()); + BaseTensorType indicesTensorType = cast(result.getType()); std::optional maybeInputRank = getTensorRank(input); if (!maybeInputRank) { return rewriter.notifyMatchFailure( @@ -1670,7 +1670,7 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so // first the input tensor is flattened to 1d tensor and then the reduction // happens on the 0th dimension. - if (dim.getType().isa()) { + if (isa(dim.getType())) { BaseTensorType flattenType = inputType .getWithSizesAndDtype({kUnknownSize}, @@ -1720,7 +1720,7 @@ class DecomposeAtenBucketizeTensorOp Location loc = op.getLoc(); Value input = op.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "unimplemented: input must have known sizes"); @@ -1728,7 +1728,7 @@ class DecomposeAtenBucketizeTensorOp ArrayRef inputShape = inputType.getSizes(); Value boundaries = op.getBoundaries(); - auto boundariesType = boundaries.getType().cast(); + auto boundariesType = cast(boundaries.getType()); if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) { return rewriter.notifyMatchFailure(op, "unimplemented: boundaries must have " @@ -1827,7 +1827,7 @@ static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value dim = op.getDim(); Value self = op.getSelf(); - BaseTensorType tensorType = self.getType().cast(); + BaseTensorType tensorType = cast(self.getType()); Value xMax = createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true); if (!xMax) @@ -1856,12 +1856,12 @@ class DecomposeAtenLogSoftmaxIntOp LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - if (!op.getDtype().getType().isa()) + if (!isa(op.getDtype().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented non-None dtype for log_softmax"); - BaseTensorType tensorType = self.getType().cast(); - if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) + BaseTensorType tensorType = cast(self.getType()); + if (!tensorType.hasDtype() || !isa(tensorType.getDtype())) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value logSoftmax = getLogSoftmaxResult(op, rewriter); @@ -1974,7 +1974,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { Type opType = op.getType(); Value dim = op.getDim(); - auto resType = self.getType().cast(); + auto resType = cast(self.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -2088,7 +2088,7 @@ class DecomposeAtenPixelShuffleOp Location loc = op.getLoc(); Value inValue = op.getSelf(); - auto inType = inValue.getType().cast(); + auto inType = cast(inValue.getType()); auto maybeSizes = inType.getOptionalSizes(); if (!maybeSizes) { return rewriter.notifyMatchFailure( @@ -2234,7 +2234,7 @@ class DecomposeAtenPixelShuffleOp // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6) static Value getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) { - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); Value relu = rewriter.create(loc, inputType, input); Value cst6 = @@ -2252,7 +2252,7 @@ class DecomposeAtenRelu6Op : public OpRewritePattern { LogicalResult matchAndRewrite(AtenRelu6Op op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -2304,7 +2304,7 @@ class DecomposeAtenLeakyReluOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Value negativeSlope = op.getNegativeSlope(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -2341,7 +2341,7 @@ class DecomposeAtenLeakyReluBackwardOp Value gradOutput = op.getGradOutput(); Value input = op.getSelf(); Value negativeSlope = op.getNegativeSlope(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -2382,7 +2382,7 @@ class DecomposeAtenPreluOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Value weight = op.getWeight(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); auto boolTensorType = rewriter.getType( resType.getOptionalSizes(), rewriter.getI1Type()); Value zero = @@ -2408,14 +2408,14 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenLerpScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); auto start = op.getSelf(); - auto inputType = start.getType().cast(); + auto inputType = cast(start.getType()); auto delta = rewriter.create(loc, inputType, op.getEnd(), start, cstOne); @@ -2442,7 +2442,7 @@ class DecomposeAtenEluOp : public OpRewritePattern { Value alpha = op.getAlpha(); Value scale = op.getScale(); Value inputScale = op.getInputScale(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -2486,7 +2486,7 @@ class DecomposeAtenSeluOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -2578,7 +2578,7 @@ class DecomposeAtenStackOp : public OpRewritePattern { } // Ensure all tensors have known sizes for (Value tensor : tensors) { - BaseTensorType tensorType = tensor.getType().cast(); + BaseTensorType tensorType = cast(tensor.getType()); if (!tensorType.hasSizes()) { return rewriter.notifyMatchFailure( op, "unimplemented: one tensor does not have known sizes"); @@ -2596,8 +2596,9 @@ class DecomposeAtenStackOp : public OpRewritePattern { } Type listElemType = - op.getType().cast().getWithSizesAndDtype( - /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); + cast(op.getType()) + .getWithSizesAndDtype( + /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value unsqueezedTensorList = rewriter.create( op.getLoc(), listType, unsqueezedTensors); @@ -2635,7 +2636,7 @@ class DecomposeAtenRollOp : public OpRewritePattern { Value constOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); auto self = op.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); // roll(input, shift, dim) = cat({ // slice(input, dim, -shift, none), // slice(input, dim, 0, -shift)}, dim) @@ -2817,7 +2818,7 @@ class DecomposeAtenRepeatInterleaveSelfIntOp if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure( op, "Unimplemented: no implementation for rankless tensor"); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasSizes()) return rewriter.notifyMatchFailure( op, "Unimplemented: no implementation for rankless tensor"); @@ -2968,7 +2969,7 @@ class DecomposeAtenUnflattenIntOp Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); - BaseTensorType outputTensorType = op.getType().cast(); + BaseTensorType outputTensorType = cast(op.getType()); if (!outputTensorType.hasSizes()) return rewriter.notifyMatchFailure( op, "unimplemented: output must have known sizes"); @@ -2977,7 +2978,7 @@ class DecomposeAtenUnflattenIntOp if (!maybeRank) return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); unsigned inputRank = *maybeRank; - auto inputTensorType = self.getType().cast(); + auto inputTensorType = cast(self.getType()); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); @@ -3077,7 +3078,7 @@ class DecomposeAtenWhereScalarOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenWhereScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -3100,7 +3101,7 @@ class DecomposeAtenWhereScalarOtherOp LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -3122,7 +3123,7 @@ class DecomposeAtenWhereScalarSelfOp LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -3186,7 +3187,7 @@ class DecomposeAtenMaskedFillScalarOp LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -3227,7 +3228,7 @@ static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, int64_t dimB, Value &transposed) { Type transposedType; - if (failed(getTransposedType(input.getType().cast(), + if (failed(getTransposedType(cast(input.getType()), dimA, dimB, transposedType))) return failure(); Value cstDimA = rewriter.create( @@ -3578,7 +3579,7 @@ class DecomposeAtenConvolutionBackwardOp op.getGroups(), op.getDilation()); Type transposedType; - if (failed(getTransposedType(input.getType().cast(), 0, 1, + if (failed(getTransposedType(cast(input.getType()), 0, 1, transposedType))) return failure(); Value inputTransposed = rewriter.create( @@ -3605,7 +3606,7 @@ class DecomposeAtenConvolutionBackwardOp ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], gradOutputSize[3]}); - BaseTensorType gradOutputTy = gradOutput.getType().cast(); + BaseTensorType gradOutputTy = cast(gradOutput.getType()); if (!gradOutputTy.hasSizes()) return failure(); SmallVector gradOutputSizesInt(gradOutputTy.getSizes()); @@ -3625,7 +3626,7 @@ class DecomposeAtenConvolutionBackwardOp loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); BaseTensorType inputTransposedTy = - inputTransposed.getType().cast(); + cast(inputTransposed.getType()); if (!inputTransposedTy.hasSizes()) return failure(); SmallVector inputTransposedSizesInt( @@ -3660,7 +3661,7 @@ class DecomposeAtenConvolutionBackwardOp /*dilation=*/op.getStride(), op.getTransposed(), op.getOutputPadding(), numGroup); - BaseTensorType weightTy = weight.getType().cast(); + BaseTensorType weightTy = cast(weight.getType()); if (!weightTy.hasSizes()) return failure(); SmallVector weightSizes(weightTy.getSizes()); @@ -3707,7 +3708,7 @@ class DecomposeAtenConvolutionBackwardOp gradWeight = rewriter.create( loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); - gradWeightTy = gradWeight.getType().cast(); + gradWeightTy = cast(gradWeight.getType()); SmallVector gradWeightDimsOrder = computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size()); SmallVector gradWeightMoveDimShape; @@ -3733,7 +3734,7 @@ class DecomposeAtenConvolutionBackwardOp /*keepdim=*/cstFalse, /*dtype=*/cstNone); } else { - if (failed(getTransposedType(gradOutput.getType().cast(), + if (failed(getTransposedType(cast(gradOutput.getType()), 0, 1, transposedType))) return failure(); Value gradOutputTransposed = rewriter.create( @@ -3792,7 +3793,7 @@ class DecomposeAtenAddmmOp : public OpRewritePattern { } // TODO: Handle integer type operands. - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasDtype() || !inputType.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "unimplemented: non-floating point dtype"); @@ -3821,7 +3822,7 @@ class DecomposeAtenMeanOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Value output = op.getResult(); - BaseTensorType outputTensorType = output.getType().cast(); + BaseTensorType outputTensorType = cast(output.getType()); Value sum = rewriter.create(loc, outputTensorType, input, op.getDtype()); Value numTensorElements = rewriter.create(loc, input); @@ -3854,7 +3855,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { Type outputType = op.getType(); MLIRContext *context = op.getContext(); - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); if (!inputType.hasDtype() || !inputType.getDtype().isa() || !isNoneOrFloatDtype(context, dtype)) { return rewriter.notifyMatchFailure( @@ -3944,7 +3945,7 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { rewriter.replaceOp(op, input); return success(); } - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); if (!inputType.hasDtype() || !inputType.getDtype().isa()) return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); @@ -3992,7 +3993,7 @@ class DeomposeAtenNativeDropoutOp rewriter.replaceOp(op, ArrayRef{input, trueMask}); return success(); } - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); if (!inputType.hasDtype() || !inputType.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); @@ -4029,7 +4030,7 @@ class DecomposeAtenVarOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned inputRank = *maybeInputRank; - BaseTensorType rank0FloatTensorTy = op.getType().cast(); + BaseTensorType rank0FloatTensorTy = cast(op.getType()); if (!rank0FloatTensorTy.hasSizes() || rank0FloatTensorTy.getSizes().size() != 0) { return rewriter.notifyMatchFailure( @@ -4060,7 +4061,7 @@ class DecomposeAtenStdOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenStdOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - BaseTensorType inputTensorTy = self.getType().cast(); + BaseTensorType inputTensorTy = cast(self.getType()); if (!inputTensorTy.hasDtype() || !inputTensorTy.getDtype().isa()) { return rewriter.notifyMatchFailure(op, @@ -4084,7 +4085,7 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); Value inputTimesBeta = rewriter.create(loc, inputType, input, op.getBeta()); @@ -4116,7 +4117,7 @@ class DecomposeAtenStdDimOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenStdDimOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - BaseTensorType inputTensorType = self.getType().cast(); + BaseTensorType inputTensorType = cast(self.getType()); if (!inputTensorType.hasDtype() || !inputTensorType.getDtype().isa()) { return rewriter.notifyMatchFailure( @@ -4141,7 +4142,7 @@ class DecomposeAtenStdCorrectionOp LogicalResult matchAndRewrite(AtenStdCorrectionOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); - BaseTensorType inputTensorType = self.getType().cast(); + BaseTensorType inputTensorType = cast(self.getType()); if (!inputTensorType.hasDtype() || !inputTensorType.getDtype().isa()) { return rewriter.notifyMatchFailure( @@ -4167,8 +4168,8 @@ class DecomposeAtenHardsigmoidOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); - BaseTensorType inputType = input.getType().cast(); - auto resType = op.getType().cast(); + BaseTensorType inputType = cast(input.getType()); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -4208,8 +4209,8 @@ class DecomposeAtenHardtanhOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); - BaseTensorType inputType = input.getType().cast(); - auto resType = op.getType().cast(); + BaseTensorType inputType = cast(input.getType()); + auto resType = cast(op.getType()); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -4235,7 +4236,7 @@ class DecomposeAtenRandLikeOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Type resultType = op.getType(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasDtype() || !inputType.getDtype().isa()) { return rewriter.notifyMatchFailure(op, "only support floating-point type"); @@ -4268,8 +4269,8 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op, Location loc, Value input, Value prob, Value &output) { - auto inputType = input.getType().cast(); - auto probType = prob.getType().cast(); + auto inputType = cast(input.getType()); + auto probType = cast(prob.getType()); // Both the `input` and `prob` must be ranked tensors. if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() || !probType.hasDtype()) { @@ -4338,12 +4339,12 @@ class DecomposeAtenBernoulliLikeOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Value p = op.getP(); - if (!op.getGenerator().getType().template isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); SmallVector empty; Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty), rewriter.getF64Type()); @@ -4485,7 +4486,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto input = op.getInput().getType().cast(); + auto input = cast(op.getInput().getType()); if (!input.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); @@ -4518,7 +4519,7 @@ class DecomposeAtenInstanceNormOp Location loc = op.getLoc(); auto context = op.getContext(); - auto inputTy = op.getInput().getType().cast(); + auto inputTy = cast(op.getInput().getType()); int64_t inputRank = inputTy.getSizes().size(); SmallVector reducedShape(inputTy.getSizes()); SmallVector reduceDimInts; @@ -4583,7 +4584,7 @@ class DecomposeAtenInstanceNormOp loc, op.getResult().getType(), inputNormalized); Value weight = op.getWeight(); - auto weightTy = weight.getType().cast(); + auto weightTy = cast(weight.getType()); dtype = weightTy.getOptionalDtype(); SmallVector weightShape(weightTy.getSizes()); @@ -4610,7 +4611,7 @@ class DecomposeAtenInstanceNormOp rewriter.create(loc, inputTy, weight, op.getInput()); Value bias = op.getBias(); - auto biasTy = bias.getType().cast(); + auto biasTy = cast(bias.getType()); dtype = biasTy.getOptionalDtype(); SmallVector biasShape(biasTy.getSizes()); @@ -4654,7 +4655,7 @@ class DecomposeAtenNativeLayerNormOp Location loc = op.getLoc(); auto context = op.getContext(); - auto inputTy = op.getInput().getType().cast(); + auto inputTy = cast(op.getInput().getType()); if (!inputTy.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); @@ -4889,10 +4890,10 @@ class DecomposeAtenNativeGroupNormOp Value eps = op.getEps(); // Check the rank of the input/outputs tensor. - auto inputType = input.getType().cast(); - auto outputType = op.getResult0().getType().cast(); - auto meanType = op.getResult1().getType().cast(); - auto rsqrtVarType = op.getResult2().getType().cast(); + auto inputType = cast(input.getType()); + auto outputType = cast(op.getResult0().getType()); + auto meanType = cast(op.getResult1().getType()); + auto rsqrtVarType = cast(op.getResult2().getType()); if (!inputType.hasSizes() || !outputType.hasSizes() || !meanType.hasSizes() || !rsqrtVarType.hasSizes()) { return rewriter.notifyMatchFailure( @@ -5059,8 +5060,8 @@ class DecomposeAtenNativeBatchNormOp SmallVector runningStatsShapeInt(inputRank, 1); runningStatsShapeInt[1] = - runningMean.getType().cast().getSizes()[0]; - Type dtype = input.getType().cast().getOptionalDtype(); + cast(runningMean.getType()).getSizes()[0]; + Type dtype = cast(input.getType()).getOptionalDtype(); Type reshapeType = ValueTensorType::get( context, llvm::ArrayRef(runningStatsShapeInt), dtype); @@ -5175,8 +5176,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); if (dtype.getType().isa()) { - BaseTensorType tensorType = - op.getSelf().getType().template cast(); + BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); @@ -5200,7 +5200,7 @@ class DecomposeAtenFullOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenFullOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - BaseTensorType outTy = op.getType().template cast(); + BaseTensorType outTy = cast(op.getType()); if (!outTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); @@ -5231,12 +5231,12 @@ class DecomposeAtenLinearOp : public OpRewritePattern { Value weight = op.getWeight(); Value bias = op.getBias(); - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); if (!inputType.hasSizes() || inputType.getSizes().size() < 2) return rewriter.notifyMatchFailure( op, "expected input to be rank 2 or greater"); - BaseTensorType weightType = weight.getType().cast(); + BaseTensorType weightType = cast(weight.getType()); // `weight` must be a rank 2 matrix. if (!weightType.hasSizes() || weightType.getSizes().size() != 2) return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2"); @@ -5255,7 +5255,7 @@ class DecomposeAtenLinearOp : public OpRewritePattern { return success(); } - BaseTensorType biasType = bias.getType().cast(); + BaseTensorType biasType = cast(bias.getType()); if (!biasType.hasSizes() || biasType.getSizes().size() != 1) return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); @@ -5280,7 +5280,7 @@ class DecomposeAtenMishOp : public OpRewritePattern { Value input = op.getSelf(); Type type = op.getType(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasDtype()) return rewriter.notifyMatchFailure(op, "Dtype not present"); @@ -5306,7 +5306,7 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFullLikeOp op, PatternRewriter &rewriter) const override { - BaseTensorType outTy = op.getType().template cast(); + BaseTensorType outTy = cast(op.getType()); if (!outTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); @@ -5335,7 +5335,7 @@ class DecomposeAtenNewFullOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); if (dtype.getType().isa()) { - BaseTensorType tensorType = op.getSelf().getType().cast(); + BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); @@ -5393,7 +5393,7 @@ class DecomposeAten_ToCopyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ToCopyOp op, PatternRewriter &rewriter) const override { - auto resultType = op.getType().cast(); + auto resultType = cast(op.getType()); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); @@ -5419,12 +5419,12 @@ class DecomposeAtenCopyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCopyOp op, PatternRewriter &rewriter) const override { - auto resultType = op.getType().cast(); + auto resultType = cast(op.getType()); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } - auto srcTy = op.getSrc().getType().cast(); + auto srcTy = cast(op.getSrc().getType()); if (!srcTy.hasSizes() || !srcTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected src type to have a known rank and dtype"); @@ -5448,7 +5448,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern { Value noneVal = rewriter.create(op.getLoc()); Value dtype = op.getDtype(); if (dtype.getType().isa()) { - BaseTensorType tensorType = op.getSelf().getType().cast(); + BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); @@ -5588,7 +5588,7 @@ class DecomposeAtenToPrimDeviceOp Value constNone = rewriter.create(loc); Value dtype = op.getDtype(); - if (dtype.getType().template isa()) { + if (isa(dtype.getType())) { dtype = rewriter.create(loc, op.getSelf()); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), @@ -5665,7 +5665,7 @@ class DecomposeAtenAdaptiveAvgPool1dOp SmallVector kernelSize; if (outputSizeInt == 1) { - BaseTensorType inputTensorType = input.getType().cast(); + BaseTensorType inputTensorType = cast(input.getType()); ArrayRef inputShape = inputTensorType.getSizes(); kernelSize.push_back( inputShape[rank - 1] == kUnknownSize @@ -5839,7 +5839,7 @@ class DecomposeAtenCosineSimilarityOp SmallVector indexBroadcastShapeValue; computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt, indexBroadcastShapeValue); - Type dtype = x1.getType().cast().getOptionalDtype(); + Type dtype = cast(x1.getType()).getOptionalDtype(); Type broadcastType = ValueTensorType::get( op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype); Value indexBroadcastShapeTorchList = rewriter.create( @@ -5925,9 +5925,9 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern { Value alphaTimesBmm = rewriter.create(loc, op.getType(), bmm, op.getAlpha()); Value input = op.getSelf(); - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); BaseTensorType resultType = - op->getResult(0).getType().cast(); + cast(op->getResult(0).getType()); if (inputType.hasDtype() && resultType.hasDtype() && inputType.getDtype() != resultType.getDtype()) { input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype()); @@ -6011,7 +6011,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, Value self = op.getSelf(); Value dimList = op.getDim(); Value keepDim = op.getKeepdim(); - BaseTensorType inputTensorTy = self.getType().cast(); + BaseTensorType inputTensorTy = cast(self.getType()); Type outputType = op.getType(); BaseTensorType outputTensorType = cast(outputType); if (!outputTensorType.hasDtype()) { @@ -6030,7 +6030,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, // computation of the result. if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) { self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type()); - inputTensorTy = self.getType().cast(); + inputTensorTy = cast(self.getType()); } std::optional maybeInputRank = getTensorRank(self); @@ -6040,7 +6040,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, unsigned inputRank = *maybeInputRank; SmallVector dimListElements; bool isNoneOrEmpty = true; - if (!dimList.getType().template isa()) { + if (!isa(dimList.getType())) { if (!getListConstructElements(dimList, dimListElements)) return rewriter.notifyMatchFailure( op, "expect dimList to be constructed from list construct"); @@ -6287,8 +6287,8 @@ class DecomposeAtenMseLossOp : public OpRewritePattern { op, "Expected a constant integer value for reduction"); Location loc = op.getLoc(); - BaseTensorType resultType = op.getType().cast(); - BaseTensorType inputType = op.getSelf().getType().cast(); + BaseTensorType resultType = cast(op.getType()); + BaseTensorType inputType = cast(op.getSelf().getType()); if (!inputType.hasSizes()) return rewriter.notifyMatchFailure( op, "Expected the input tensor to have sizes"); @@ -6506,7 +6506,7 @@ class DecomposeAtenRandnGeneratorOp LogicalResult matchAndRewrite(AtenRandnGeneratorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resultType = op.getType().cast(); + auto resultType = cast(op.getType()); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -6617,7 +6617,7 @@ class DecomposeAtenRandOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resultType = op.getType().cast(); + auto resultType = cast(op.getType()); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -6943,7 +6943,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern { auto context = op.getContext(); Value input = op.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); @@ -6974,7 +6974,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern { // compare auto eqType = ValueTensorType::get( - context, op.getType().cast().getSizes(), + context, cast(op.getType()).getSizes(), IntegerType::get(context, 1)); Value eqTensor = rewriter.create( loc, eqType, unsqueezeTensor, arangeTensor); @@ -7019,7 +7019,7 @@ class DecomposeAtenScalarTensor : public OpRewritePattern { LogicalResult matchAndRewrite(AtenScalarTensorOp op, PatternRewriter &rewriter) const override { - auto resultTy = op.getResult().getType().cast(); + auto resultTy = cast(op.getResult().getType()); auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType()); Value numToTensor = rewriter.create( op.getLoc(), @@ -7060,7 +7060,7 @@ class DecomposeAtenTopkOp : public OpRewritePattern { Value self = op.getSelf(); Value dim = op.getDim(); - auto selfType = self.getType().cast(); + auto selfType = cast(self.getType()); auto sortIndicesType = selfType.getWithSizesAndDtype( selfType.getOptionalSizes(), IntegerType::get(context, 64, IntegerType::Signed)); @@ -7111,8 +7111,8 @@ class DecomposeAtenScatterValueOp Value sizeList = rewriter.create( loc, ListType::get(IntType::get(context)), sizes); - auto selfType = self.getType().cast(); - auto indexType = index.getType().cast(); + auto selfType = cast(self.getType()); + auto indexType = cast(index.getType()); BaseTensorType srcType = selfType .getWithSizesAndDtype(indexType.getOptionalSizes(), @@ -7135,7 +7135,7 @@ class DecomposeAtenSgnOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenSgnOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto outType = op.getType().cast(); + auto outType = cast(op.getType()); if (!outType.hasDtype()) { return rewriter.notifyMatchFailure(op, "expected result type to have dtype"); @@ -7273,14 +7273,14 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern { "failed to get elements of `indices`"); auto input = op.getSelf(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "only input with shape information is supported"); } auto inputSizes = inputType.getSizes(); int64_t inputRank = inputSizes.size(); - auto outputType = op.getType().cast(); + auto outputType = cast(op.getType()); if (!outputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "only output with shape information is supported"); @@ -7438,7 +7438,7 @@ class DecomposeAtenTileOp : public OpRewritePattern { op, "failed to get elements of `dims` param"); } auto dimsSize = dimsElements.size(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "only support input tensor with shape information"); diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index 3b30e9424f44..0c352d31ca80 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -89,7 +89,7 @@ class QuantizeTransposedOperands : public OpRewritePattern { .cast() .getOptionalDtype(); auto torchQType = - quant.getType().cast().getOptionalDtype(); + cast(quant.getType()).getOptionalDtype(); auto transQTy = rewriter.getType(trans.getResult() .getType() @@ -152,7 +152,7 @@ template class QuantizeBias : public OpRewritePattern { return failure(); Value bias = operands[2]; - auto biasTy = bias.getType().dyn_cast(); + auto biasTy = dyn_cast(bias.getType()); if (biasTy) { auto biasETy = biasTy.getOptionalDtype(); diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index 239960629797..dbf203584601 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -134,7 +134,7 @@ class ObjectGraphInfo { slotName = setAttrOp.getName(); } - auto moduleType = module.getType().cast(); + auto moduleType = cast(module.getType()); auto slots = moduleClassNameToSlots.find(moduleType.getClassName()); // TODO: Improve verifier so that this can never happen if (slots == moduleClassNameToSlots.end()) @@ -163,13 +163,13 @@ class ObjectGraphInfo { } auto classType = symbolTable.lookup( - nnModule.getType().cast().getClassName()); + cast(nnModule.getType()).getClassName()); for (auto t : llvm::zip(nnModule.getOps(), classType.getOps())) { auto slot = std::get<0>(t); auto attr = std::get<1>(t); nameStack.push_back(attr.getName().str()); - if (attr.getType().isa()) { + if (isa(attr.getType())) { if (failed(recursivelyTraverse( slot.getValue().getDefiningOp()))) return failure(); @@ -333,7 +333,7 @@ static LogicalResult analyzeInstances(func::FuncOp func, for (auto &argInstance : argInstances) mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance); auto walkResult = func.walk([&](PrimGetAttrOp op) { - if (!op.getType().isa()) + if (!isa(op.getType())) return WalkResult::advance(); auto instance = mapping.lookupOrNull(op.getReceiver()); assert(instance && "verifyFuncConformsToSubset should ensure this"); @@ -355,7 +355,7 @@ createMonomorphizationForCall(func::CallOp op, IRMapping &mapping, Monomorphization monomorphization; monomorphization.func = func; for (auto operand : llvm::enumerate(op->getOperands())) { - if (!operand.value().getType().isa()) + if (!isa(operand.value().getType())) continue; Value instance = mapping.lookupOrNull(operand.value()); assert(instance && "verifyFuncConformsToSubset should ensure this"); @@ -377,7 +377,7 @@ class MonomorphizationTracker { monomorphization.func = func; bool canTriviallyMonomorphize = true; for (auto arg : llvm::enumerate(func.getArguments())) { - auto type = arg.value().getType().dyn_cast(); + auto type = dyn_cast(arg.value().getType()); if (!type) continue; auto classType = symbolTable.lookup(type.getClassName()); @@ -436,7 +436,7 @@ class MonomorphizationTracker { // !torch.nn.Module<"..."> types. static LogicalResult verifyNnModuleValueUses(Value value) { // Trivially succeed for non-module types. - if (!value.getType().isa()) + if (!isa(value.getType())) return success(); for (Operation *op : value.getUsers()) { if (isa(op)) @@ -516,7 +516,7 @@ static LogicalResult rewriteMonomorphizedFuncClone( return WalkResult::advance(); }; auto handlePrimGetAttr = [&](PrimGetAttrOp op) { - if (!op.getType().isa()) { + if (!isa(op.getType())) { auto instance = mapping.lookup(op.getReceiver()).getDefiningOp(); SlotOp affectedSlot; @@ -540,7 +540,7 @@ static LogicalResult rewriteMonomorphizedFuncClone( Monomorphization monomorphization = std::move(*maybeMonomorphization); auto newArguments = llvm::to_vector<6>( llvm::make_filter_range(op->getOperands(), [](Value v) { - return !v.getType().isa(); + return !isa(v.getType()); })); assert(newFuncs.find(monomorphization) != newFuncs.end()); auto newOp = OpBuilder(op).create( @@ -564,7 +564,7 @@ static LogicalResult rewriteMonomorphizedFuncClone( } llvm::BitVector argsToErase(func.getNumArguments()); for (auto type : llvm::enumerate(func.getArgumentTypes())) { - if (type.value().isa()) { + if (isa(type.value())) { argsToErase.set(type.index()); } } diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 5d59dfd8c596..2aa9f42307b1 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -248,8 +248,8 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) { })) continue; if (auto initialize = dyn_cast(op)) { - auto symName = initialize.getSlotSymNames()[use.getOperandNumber()] - .cast(); + auto symName = cast( + initialize.getSlotSymNames()[use.getOperandNumber()]); auto *state = getOrCreateFor( value, getProgramPoint(symName)); if (state->isSafe) @@ -333,10 +333,10 @@ class InlineGlobalSlotsPass DenseSet safeToInline; for (int i = 0, e = initialize->getNumOperands(); i != e; i++) { auto slotSymName = - initialize.getSlotSymNames()[i].cast(); + cast(initialize.getSlotSymNames()[i]); Value operand = initialize.getOperand(i); auto symbolRefPoint = solver.getProgramPoint( - initialize.getSlotSymNames()[i].cast()); + cast(initialize.getSlotSymNames()[i])); auto *state = solver.lookupState(symbolRefPoint); // We roll the analysis of whether a slot is set or public into the @@ -408,7 +408,7 @@ class InlineGlobalSlotsPass SmallVector newInitialValues; for (int i = 0, e = initialize.getNumOperands(); i != e; i++) { auto slotSymName = - initialize.getSlotSymNames()[i].cast(); + cast(initialize.getSlotSymNames()[i]); if (!safeToInline.count(slotSymName)) { newSlotSymNames.push_back(slotSymName); newInitialValues.push_back(initialize.getOperand(i)); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index e1377afce373..b3318c6c1c72 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -118,7 +118,7 @@ static LogicalResult checkType(Operation *op, Type type, if (auto optionalType = dyn_cast(type)) { // TODO: Be stricter about tensor types. // See comment below for ListType. - if (optionalType.getContainedType().isa()) + if (isa(optionalType.getContainedType())) return success(); return checkType(op, optionalType.getContainedType(), actuallyEmitDiagnostics); @@ -134,7 +134,7 @@ static LogicalResult checkType(Operation *op, Type type, // the contained type information. Somehow this slips through and works. // We should be stricter about this and properly infer the contained type // and shape. - if (listType.getContainedType().isa()) + if (isa(listType.getContainedType())) return success(); return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics); } @@ -535,7 +535,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, } target.addDynamicallyLegalOp( [backendLegalOpsSet](OperatorOp opOp) { - auto opName = opOp->getAttr("name").cast().getValue(); + auto opName = cast(opOp->getAttr("name")).getValue(); return backendLegalOpsSet.contains(opName); }); } diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp index 147f16c08eb3..c237ede12479 100644 --- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -62,7 +62,7 @@ class MatchQuantizeOperator : public OpRewritePattern { op.getLoc(), op.getOperand(0).getType(), op.getOperand(0), op.getOperand(3), op.getOperand(4)); - auto clampTy = clamp.getType().cast(); + auto clampTy = cast(clamp.getType()); if (!clampTy.hasDtype()) return rewriter.notifyMatchFailure(op, "dequantization has unknown dtype"); diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 4026d0464dca..cd4b74be678e 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -23,7 +23,7 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; static Value assertNonValueTensor(Value tensor) { - assert(tensor.getType().isa() && + assert(isa(tensor.getType()) && "tensor is expected to be a non-value tensor"); return tensor; } @@ -102,7 +102,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock // to use value semantics (which happens for example with ops // that take two aliases as input), then it is possible that the // op no longer generates an alias. - if (userResult.getType().isa()) + if (isa(userResult.getType())) availableAliases.insert(userResult); result.viewLikeOps.push_back(user); } else if (auto copyToValueTensor = dyn_cast(user)) { @@ -177,7 +177,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock for (Operation *viewLikeOp : ops.viewLikeOps) { rewriter.modifyOpInPlace(viewLikeOp, [&] { Value result = viewLikeOp->getResult(0); - auto resultType = result.getType().dyn_cast(); + auto resultType = dyn_cast(result.getType()); if (resultType) result.setType(resultType.getWithValueSemantics()); }); @@ -230,7 +230,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock if (isViewLikeOp(op)) { // We currently only support view-like ops with one tensor output. if (op->getNumResults() != 1 || - !op->getResult(0).getType().isa()) { + !isa(op->getResult(0).getType())) { return rewriter.notifyMatchFailure( copy, "unsupported: view-like ops must have one tensor output, " "and the tensor output must be the first result"); @@ -242,7 +242,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock // non-value tensor and the output being a value tensor. If this is the // case then there is no need to look at the users of the result of the // op. - if (opResult.getType().isa()) { + if (isa(opResult.getType())) { if (operand.getOperandNumber() == 0) { validViewLikeOps.insert(op); llvm::append_range(workList, opResult.getUses()); @@ -339,7 +339,7 @@ class RewriteViewLikeSubgraph for (Operation *op : viewLikeOps) { rewriter.modifyOpInPlace(op, [&]() { if (auto nonValueTensorType = - op->getResult(0).getType().dyn_cast()) { + dyn_cast(op->getResult(0).getType())) { originalTypes[op->getResult(0)] = nonValueTensorType; op->getResult(0).setType(nonValueTensorType.getWithValueSemantics()); } diff --git a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp index 279cbc41d4a6..93a44ac33adc 100644 --- a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp @@ -30,7 +30,7 @@ class ConvertPrimCallMethodToCall : public OpRewritePattern { LogicalResult matchAndRewrite(PrimCallMethodOp op, PatternRewriter &rewriter) const override { auto classType = symbolTable.lookup( - op.getReceiver().getType().cast().getClassName()); + cast(op.getReceiver().getType()).getClassName()); assert(classType && "malformed module -- missing ClassTypeOp"); func::FuncOp func; for (auto method : classType.getOps()) { @@ -94,7 +94,7 @@ class PrepareForGlobalizeObjectGraphPass ConversionTarget target(*context); target.addIllegalOp(); target.addDynamicallyLegalOp( - [](func::ConstantOp op) { return !op.getType().isa(); }); + [](func::ConstantOp op) { return !isa(op.getType()); }); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index 01a962f0b270..c1e476a80a10 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -78,7 +78,7 @@ class RecomposeSliceCopy_ : public OpRewritePattern { Value falseVal = rewriter.create(op.getLoc(), false); // Create IndexPut_Op - BaseTensorType tensorType = op.getType().cast(); + BaseTensorType tensorType = cast(op.getType()); Type rangeType = tensorType.getWithSizesAndDtype( {kUnknownSize}, tensorType.getOptionalDtype()); Value range = rewriter.create( @@ -130,8 +130,7 @@ class RecomposeSelectFill_ : public OpRewritePattern { // Create IndexPut_Op // Convert indexNum to indexTensor for the selectOp - BaseTensorType selectOutTy = - selectOp.getType().template cast(); + BaseTensorType selectOutTy = cast(selectOp.getType()); SmallVector empty; auto dtype = getTypeForTorchType(selectOp.getContext(), selectOp.getIndex().getType()); @@ -141,7 +140,7 @@ class RecomposeSelectFill_ : public OpRewritePattern { selectOp.getLoc(), emptyTensorType, selectOp.getIndex()); // Create indicesVector for IndexPut_Op by TorchNone and indexTensor - BaseTensorType tensorType = op->getResultTypes()[0].cast(); + BaseTensorType tensorType = cast(op->getResultTypes()[0]); SmallVector indicesVector(dim, noneVal); indicesVector.push_back(indexTensor); diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 746b9068284c..8b758a135751 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -26,9 +26,9 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter, Location loc, Value overwriterTensor, Value overwrittenTensor) { Type overwriterTensorType = overwriterTensor.getType(); - Type overwrittenTensorType = overwrittenTensor.getType() - .dyn_cast() - .getWithValueSemantics(); + Type overwrittenTensorType = + dyn_cast(overwrittenTensor.getType()) + .getWithValueSemantics(); if (overwriterTensorType != overwrittenTensorType) { overwriterTensor = rewriter.create( loc, overwrittenTensorType, overwriterTensor); @@ -58,7 +58,7 @@ operatorOpHasValueSemantics(OperatorOp opOp, std::optional extraLibrary) { if (!extraLibrary.has_value()) return false; - auto opName = opOp->getAttr("name").cast().getValue(); + auto opName = cast(opOp->getAttr("name")).getValue(); std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix( LibraryFunctionKind::HasValueSemantics) + Twine(opName)) @@ -96,8 +96,8 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { opOperand.set(rewriter.create(op->getLoc(), opOperand.get())); } else if (auto listType = dyn_cast(operandType)) { - if (!(listType.getContainedType().isa() || - listType.getContainedType().isa())) + if (!(isa(listType.getContainedType()) || + isa(listType.getContainedType()))) continue; // Construct a new list whose elements are value tensors copied from @@ -116,7 +116,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { // TODO: Handle optional type in list type. if (auto optionalType = - listType.getContainedType().dyn_cast()) { + dyn_cast(listType.getContainedType())) { if (!llvm::all_of(listConstruct.getElements(), [](Value val) { return val.getType().isa(); })) { @@ -129,7 +129,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { auto newListElements = llvm::to_vector(llvm::map_range( listConstruct.getElements(), [&](Value tensor) -> Value { - if (tensor.getType().isa()) { + if (isa(tensor.getType())) { return rewriter.create(op->getLoc(), tensor); } @@ -147,7 +147,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { } else if (auto optionalType = dyn_cast(operandType)) { // TODO: A more general way to handle the optional type is to // introduce a `copy.to_optional_vtensor` op. - if (!optionalType.getContainedType().isa()) + if (!isa(optionalType.getContainedType())) continue; // Create a new optional value whose input is a value tensor copied @@ -160,7 +160,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { "derefine"); } - if (!derefine.getOperand().getType().isa()) + if (!isa(derefine.getOperand().getType())) continue; auto newOperand = rewriter.create( op->getLoc(), derefine.getOperand()); @@ -172,7 +172,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { // Convert all results. rewriter.setInsertionPointAfter(op); for (Value result : op->getResults()) { - auto tensorType = result.getType().dyn_cast(); + auto tensorType = dyn_cast(result.getType()); if (!tensorType) continue; result.setType(tensorType.getWithValueSemantics()); diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index cfa4e40ee908..373680495f41 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -84,7 +84,7 @@ class RefinePublicReturnPass } } - if (auto tensorType = newOperand.getType().dyn_cast()) { + if (auto tensorType = dyn_cast(newOperand.getType())) { newOperands.push_back( copyTensorToType(builder, returnOp->getLoc(), tensorType.getWithValueSemantics(), newOperand)); diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 8049d8af8d59..3b25e12c3a8e 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -118,7 +118,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( assert(call.getNumResults() == 1 && "Multiple results are packed in a tuple in Python!"); Value result = call.getResult(0); - if (auto tupleType = result.getType().dyn_cast()) { + if (auto tupleType = dyn_cast(result.getType())) { auto unpack = b.create( loc, tupleType.getContainedTypes(), result); llvm::append_range(unpackedResults, unpack.getResults()); @@ -275,7 +275,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // for i in range(len(operand)): // adjusted_list.append(adjust(operand[i])) // return adjusted_list - auto providedType = operand.getType().cast(); + auto providedType = cast(operand.getType()); Value adjustedList = b.create(loc, desiredListType, ValueRange({})); // Create a for-like PrimLoopOp. @@ -312,7 +312,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // signature uses `Scalar` (see comments in torch_ods_gen.py for // explanation). if (isa(desiredType) && - operand.getType().isa()) { + isa(operand.getType())) { return b.create(loc, desiredType, operand).getResult(); } diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index 860ae79bdb86..3e9ec336641d 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -30,7 +30,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand, Type desiredType) -> Value { if (isa(desiredType) && - operand.getType().isa()) { + isa(operand.getType())) { Type intType = Torch::IntType::get(b.getContext()); Type sizeListType = Torch::ListType::get(intType); Value size = b.create(loc, sizeListType, operand); diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index 9b1c5e7fdccd..fb9d33123a9c 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -41,8 +41,8 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc, auto desiredListType = dyn_cast(desiredType); if (!desiredListType) return operand; - if (operand.getType().isa() && - desiredListType.getContainedType().isa()) { + if (isa(operand.getType()) && + isa(desiredListType.getContainedType())) { return b.create(loc, desiredType, operand); } return operand; diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp index 05daa41382cd..f1ebeb307976 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp @@ -259,7 +259,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, Type originalResultType = result.getType(); Type updatedType; if (auto originalBaseTensorType = - originalResultType.template dyn_cast()) { + dyn_cast(originalResultType)) { // If we didn't get any new information, there is nothing left for us to do. updatedType = meetTensorTypes(originalBaseTensorType, cast(newResultType)); @@ -267,7 +267,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, return rewriter.notifyMatchFailure( calculateOp, "New type information does not refine old type"); } else if (auto originalResultType = - result.getType().template dyn_cast()) { + dyn_cast(result.getType())) { if (!isa(newResultType)) { return rewriter.notifyMatchFailure( calculateOp, diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index d68b0d4bd3a7..6b18af04dca6 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -35,7 +35,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op, // Calculate the updated type incorporating the new information. Type impliedTypeFromDtype; - if (result.getType().isa()) { + if (isa(result.getType())) { FailureOr torchType = getTorchTypeForScalarType(op->getContext(), dtypeScalarType); if (failed(torchType)) { @@ -45,7 +45,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op, } impliedTypeFromDtype = *torchType; } else if (auto originalResultType = - result.getType().dyn_cast()) { + dyn_cast(result.getType())) { FailureOr builtinType = getTypeForScalarType(op->getContext(), dtypeScalarType); if (failed(builtinType)) { @@ -168,12 +168,12 @@ class RefineNumToTensorScalarOpType using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op, PatternRewriter &rewriter) const override { - auto originalResultType = op.getResult().getType().cast(); + auto originalResultType = cast(op.getResult().getType()); if (originalResultType.hasDtype()) return rewriter.notifyMatchFailure( op, "`PrimNumToTensorScalarOp` already has a dtype"); - if (op.getA().getType().isa()) { + if (isa(op.getA().getType())) { return rewriter.notifyMatchFailure(op, "`PrimNumToTensorScalarOp`'s input " "should have concrete Scalar Type."); diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index c56376a6c1bc..37ce829cb731 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -27,7 +27,7 @@ class DecomposeAtenSizeOp : public OpRewritePattern { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); - auto tensorType = self.getType().cast(); + auto tensorType = cast(self.getType()); if (!tensorType.hasSizes()) return rewriter.notifyMatchFailure(op, "unranked tensor"); int64_t rank = tensorType.getSizes().size(); @@ -96,7 +96,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, sizes.push_back(kUnknownSize); } - auto originalResultType = result.getType().cast(); + auto originalResultType = cast(result.getType()); auto impliedTypesFromShape = cast(originalResultType) .getWithSizesAndDtype(ArrayRef(sizes), diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 3a0117681fec..d634556c98a1 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -44,9 +44,9 @@ bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { } torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { - if (type.isa()) + if (isa(type)) return torch_upstream::ScalarType::Float; - if (type.isa()) + if (isa(type)) return torch_upstream::ScalarType::Double; if (type.isSignedInteger(64)) return torch_upstream::ScalarType::Long; @@ -64,11 +64,11 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::Byte; if (type.isSignedInteger(8)) return torch_upstream::ScalarType::Char; - if (type.isa()) + if (isa(type)) return torch_upstream::ScalarType::QUInt8; - if (type.isa()) + if (isa(type)) return torch_upstream::ScalarType::QInt8; - if (type.isa()) + if (isa(type)) return torch_upstream::ScalarType::QInt32; if (isa(type)) { mlir::Type complexElemType = cast(type).getElementType(); @@ -185,7 +185,7 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc, // Helper to convert a tensor to a specific scalar type. Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input, Type dtype) { - BaseTensorType origType = input.getType().cast(); + BaseTensorType origType = cast(input.getType()); Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype); // `convertIntVal` contains the corresponding integer for the dtype which is // used by the aten.to.dtype op. @@ -202,7 +202,7 @@ bool Torch::isBuiltInType(Type type) { } std::optional Torch::getTensorRank(Value tensor) { - BaseTensorType tensorType = tensor.getType().cast(); + BaseTensorType tensorType = cast(tensor.getType()); if (!tensorType.hasSizes()) return std::nullopt; return tensorType.getSizes().size(); @@ -279,7 +279,7 @@ SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { // Return the squeezed tensor or failure. FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op, Location loc, int64_t dim, Value input) { - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure(loc, "input tensor must have size"); } @@ -314,7 +314,7 @@ FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op, // Return the unsqueezed tensor or failure. FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value input, Value dim) { - BaseTensorType inputType = input.getType().cast(); + BaseTensorType inputType = cast(input.getType()); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure(op, "input tensor must have size"); } @@ -348,9 +348,9 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, SmallVector &resultShape, SmallVector &resultShapeValue) { SmallVector shapeA{ - inputA.getType().cast().getSizes()}; + cast(inputA.getType()).getSizes()}; SmallVector shapeB{ - inputB.getType().cast().getSizes()}; + cast(inputB.getType()).getSizes()}; unsigned rankA = shapeA.size(); unsigned rankB = shapeB.size(); unsigned minRank = rankA > rankB ? rankB : rankA; @@ -504,9 +504,8 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, BaseTensorType inputType, Value scalar) { assert(inputType.hasDtype() && "input must have dtype"); SmallVector sizes; - BaseTensorType rank0TensorTy = - inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) - .cast(); + BaseTensorType rank0TensorTy = cast( + inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())); Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), ValueRange{}); @@ -531,9 +530,9 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { return rewriter.getF32Type(); if (inputType.isBF16()) return rewriter.getF32Type(); - if (inputType.isa()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isa()) + if (isa(inputType)) return rewriter.getF64Type(); if (inputType.isFloat8E5M2()) return rewriter.getF32Type(); diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index e165595fc6e1..a81c27d92845 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -34,9 +34,9 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) { //===----------------------------------------------------------------------===// LogicalResult ToBuiltinTensorOp::verify() { - auto resultType = getResult().getType().cast(); + auto resultType = cast(getResult().getType()); auto operandType = - getOperand().getType().cast().toBuiltinTensor(); + cast(getOperand().getType()).toBuiltinTensor(); if (!haveSameSizeAndElementType(resultType, operandType)) { return emitError() << "operand and result must have the same size and dtype"; @@ -49,7 +49,7 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes( DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto resultType = - operands[0].getType().cast().toBuiltinTensor(); + cast(operands[0].getType()).toBuiltinTensor(); if (!resultType) return failure(); inferredReturnTypes.push_back(resultType); @@ -62,8 +62,8 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes( LogicalResult FromBuiltinTensorOp::verify() { auto resultType = - getResult().getType().cast().toBuiltinTensor(); - auto operandType = getOperand().getType().cast(); + cast(getResult().getType()).toBuiltinTensor(); + auto operandType = cast(getOperand().getType()); if (!haveSameSizeAndElementType(resultType, operandType)) { return emitError() << "operand and result must have the same size and dtype"; diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 1cda55724ee3..947011ea8338 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -36,7 +36,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - if (!inputs[0].getType().isa()) + if (!isa(inputs[0].getType())) return {}; return builder.create(loc, inputs[0]); }); @@ -44,7 +44,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, Torch::ValueTensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, type, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); @@ -64,13 +64,13 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, if (!(type.getWidth() == 1 && type.isSignless())) return std::nullopt; assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, inputs[0]).getResult(); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); @@ -99,7 +99,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); @@ -116,13 +116,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, [](OpBuilder &builder, Float64Type type, ValueRange inputs, Location loc) -> std::optional { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, inputs[0]).getResult(); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); @@ -153,7 +153,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); + assert(isa(inputs[0].getType())); return builder.create(loc, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 12e30f287f3f..36292a0f0570 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -42,7 +42,7 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { // get inputs: lhs, rhsQuant, scales, zps Value lhs = adaptor.getOperands()[0]; - auto lhsType = lhs.getType().cast(); + auto lhsType = cast(lhs.getType()); if (!lhsType) { return failure(); } @@ -50,7 +50,7 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { int lhsReductDimSize = lhsShape.back(); Value rhsQuant = adaptor.getOperands()[1]; - auto rhsType = rhsQuant.getType().cast(); + auto rhsType = cast(rhsQuant.getType()); if (!rhsType) { return failure(); } diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 064c87f6e6a8..1e6879530ce6 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -59,7 +59,7 @@ class UnpackQuantizedMatmulWeights if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth))) return failure(); - auto rhsType = rhs.getType().dyn_cast(); + auto rhsType = dyn_cast(rhs.getType()); if (!rhsType) return failure(); @@ -88,7 +88,7 @@ class UnpackQuantizedMatmulWeights ValueTensorType newRhsType = ValueTensorType::get( rewriter.getContext(), tensorShape, unpackedElementType); - auto elements = constOp.getValueAttr().dyn_cast(); + auto elements = dyn_cast(constOp.getValueAttr()); if (!elements) return failure(); diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 1cf52144e0a7..3bd16ed38940 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -234,7 +234,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp, if (!globalOp.getValue().has_value()) return globalOp.emitError("global op must have a value"); - RankedTensorType tensorType = globalOp.getType().cast(); + RankedTensorType tensorType = cast(globalOp.getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); @@ -252,7 +252,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp, static LogicalResult bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp, OpBuilder &b, SmallVector &toErase) { - RankedTensorType tensorType = globalLoadOp.getType().cast(); + RankedTensorType tensorType = cast(globalLoadOp.getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); @@ -271,7 +271,7 @@ bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp, OpBuilder &b, SmallVector &toErase) { RankedTensorType tensorType = - globalStoreOp.getValue().getType().cast(); + cast(globalStoreOp.getValue().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); @@ -300,7 +300,7 @@ class MLProgramBufferize : public MLProgramBufferizeBase { SmallVector toErase; auto walkResult = module.walk([&](ml_program::GlobalOp op) { - if (auto type = op.getType().dyn_cast()) { + if (auto type = dyn_cast(op.getType())) { if (!type.hasStaticShape()) { // If the ml_program.global has dynamically shaped tensor. op.emitError( @@ -387,8 +387,8 @@ mlir::torch::RefBackend::createExpandOpsForLLVMPass() { Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from, Value to) { - auto memrefTypeFrom = from.getType().cast(); - auto memrefTypeTo = to.getType().cast(); + auto memrefTypeFrom = cast(from.getType()); + auto memrefTypeTo = cast(to.getType()); (void)memrefTypeFrom; assert(memrefTypeFrom && memrefTypeTo && memrefTypeFrom.getRank() == memrefTypeTo.getRank());