Skip to content

Commit

Permalink
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy authored Apr 27, 2024
1 parent 466618e commit 6679728
Show file tree
Hide file tree
Showing 56 changed files with 936 additions and 983 deletions.
25 changes: 11 additions & 14 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
int64_t dimA, int64_t dimB,
Value &transposed) {
Type transposedType;
if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
dimA, dimB, transposedType)))
return failure();
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
Expand Down Expand Up @@ -554,7 +554,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// conversions which are not supported in Torch-MLIR right now.

Torch::ValueTensorType targetTy =
target.getType().cast<Torch::ValueTensorType>();
cast<Torch::ValueTensorType>(target.getType());
if (!targetTy.hasDtype()) {
return rewriter.notifyMatchFailure(binder.op,
"target tensor must have a dtype");
Expand Down Expand Up @@ -753,9 +753,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();
Type listElemType =
tensors[0]
.getType()
.cast<Torch::BaseTensorType>()
cast<Torch::BaseTensorType>(tensors[0].getType())
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Expand Down Expand Up @@ -869,7 +867,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();

auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
if (!weightTensorType || !weightTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected weight type having sizes");
Expand Down Expand Up @@ -1188,7 +1186,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();

auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
if (!weightTensorType || !weightTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected weight type having sizes");
Expand Down Expand Up @@ -1427,7 +1425,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.customOpNameStringAttr(mode, "mode", "DCR") ||
binder.tensorResultType(resultType))
return failure();
auto inputTy = input.getType().dyn_cast<Torch::BaseTensorType>();
auto inputTy = dyn_cast<Torch::BaseTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected input type having sizes");
Expand Down Expand Up @@ -1536,9 +1534,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value scale = operands[1];
Value zeropoint = operands[2];

auto operandTy = operand.getType().cast<Torch::ValueTensorType>();
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());

auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
if (!scaleTy || !scaleTy.hasSizes())
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
if (!resultType.hasDtype())
Expand Down Expand Up @@ -1611,7 +1609,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
Value trainVal = operands[2];
auto trainTensorType =
trainVal.getType().dyn_cast<Torch::BaseTensorType>();
dyn_cast<Torch::BaseTensorType>(trainVal.getType());
if (!trainTensorType)
return rewriter.notifyMatchFailure(binder.op,
"train tensor must have a type");
Expand All @@ -1629,8 +1627,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

if (auto valueTensorLiteralOp =
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
auto val = cast<DenseElementsAttr>(valueTensorLiteralOp.getValue())
.getSplatValue<bool>();
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
} else {
Expand Down Expand Up @@ -2072,7 +2069,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
SmallVector<Value> dimList;
Torch::BaseTensorType shapeType =
shape.getType().cast<Torch::BaseTensorType>();
cast<Torch::BaseTensorType>(shape.getType());
Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
Value zero = rewriter.create<Torch::ConstantIntOp>(
Expand Down
18 changes: 9 additions & 9 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return rewriter.notifyMatchFailure(
binder.op, "operand grid_sampler bind failure");

auto inputTensorType = input.getType().cast<Torch::ValueTensorType>();
auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
uint32_t inputRank = inputShape.size();
auto gridTensorType = grid.getType().cast<Torch::ValueTensorType>();
auto gridTensorType = cast<Torch::ValueTensorType>(grid.getType());
ArrayRef<int64_t> gridShape = gridTensorType.getSizes();
uint32_t gridRank = gridShape.size();

Expand Down Expand Up @@ -233,7 +233,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
axis = rank + axis;
}
// need input type and sizes to flatten/unflatten later.
auto inputTy = input.getType().cast<Torch::ValueTensorType>();
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes())
return rewriter.notifyMatchFailure(
binder.op, "failed to get input type or sizes");
Expand Down Expand Up @@ -1065,7 +1065,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

auto transpose = [&](Value m) -> Value {
auto tty = m.getType().cast<Torch::ValueTensorType>();
auto tty = cast<Torch::ValueTensorType>(m.getType());
auto shape = tty.getOptionalSizes();
if (shape.has_value()) {
llvm::SmallVector<int64_t> newShape(shape.value());
Expand Down Expand Up @@ -1134,7 +1134,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType))
return failure();

auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
if (!inputTensorType || !inputTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected input type having sizes");
Expand Down Expand Up @@ -1228,7 +1228,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rank = *maybeRank;
SmallVector<Value> normalized;
axis = Torch::toPositiveDim(axis, rank);
auto xType = x.getType().cast<Torch::ValueTensorType>();
auto xType = cast<Torch::ValueTensorType>(x.getType());
if (!xType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected input (X) to have sizes");
Expand Down Expand Up @@ -1307,7 +1307,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

// Get pads shape and rank. The pads tensor is expected to be 1-D
// tensor.
auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
if (!padsTensorType || !padsTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expect non empty pad tensor");
Expand All @@ -1323,7 +1323,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// As per onnx.Pad documentation, padSize = 2*num_data_axes
// (if axes param not passed). Need to be updated when adding
// support for `axes` param.
auto dataOpTy = data.getType().cast<Torch::ValueTensorType>();
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
TensorType dataTensor = dataOpTy.toBuiltinTensor();
if (!dataTensor || !dataTensor.hasRank())
return rewriter.notifyMatchFailure(
Expand All @@ -1350,7 +1350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}

if (!constantValue) {
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
if (dataTensorType.getDtype().isa<IntegerType>())
constantValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Expand Down
Loading

0 comments on commit 6679728

Please sign in to comment.