Commit 27169dc: Replace some deprecated uses of cast (llvm#3343)

Contributing towards llvm#3299

zjgarvey authored May 23, 2024 (parent: 5bb1a65)
Showing 4 changed files with 32 additions and 36 deletions.
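The change is mechanical throughout: LLVM and MLIR have deprecated the member-function cast API on Type, Attribute, and Value (`x.cast<T>()`, `x.dyn_cast<T>()`, `x.dyn_cast_or_null<T>()`) in favor of the free functions `cast<T>(x)`, `dyn_cast<T>(x)`, and `dyn_cast_or_null<T>(x)`, with the same semantics (cast asserts on a type mismatch; dyn_cast returns a null value instead). A minimal sketch of the rewrite; the helper name is illustrative, not from the patch:

#include "mlir/IR/Types.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch;

// Illustrative helper, not part of the commit.
Type getListElementType(Type t) {
  // Before (deprecated member syntax, asserts if t is not a ListType):
  //   return t.cast<Torch::ListType>().getContainedType();
  // After (free-function syntax, same assert-on-mismatch semantics):
  return cast<Torch::ListType>(t).getContainedType();
}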
lib/CAPI/TorchTypes.cpp (40 changes: 20 additions & 20 deletions)
@@ -51,7 +51,7 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
 }
 
 MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
-  auto type = unwrap(t).cast<Torch::OptionalType>();
+  auto type = cast<Torch::OptionalType>(unwrap(t));
   return wrap(type.getContainedType());
 }
 
@@ -77,12 +77,12 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
 }
 
 size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
-  auto type = unwrap(t).cast<Torch::TupleType>();
+  auto type = cast<Torch::TupleType>(unwrap(t));
   return type.getContainedTypes().size();
 }
 
 MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
-  auto type = unwrap(t).cast<Torch::TupleType>();
+  auto type = cast<Torch::TupleType>(unwrap(t));
   return wrap(type.getContainedTypes()[pos]);
 }
 
@@ -108,12 +108,12 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
 }
 
 size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
-  auto type = unwrap(t).cast<Torch::UnionType>();
+  auto type = cast<Torch::UnionType>(unwrap(t));
   return type.getContainedTypes().size();
 }
 
 MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
-  auto type = unwrap(t).cast<Torch::UnionType>();
+  auto type = cast<Torch::UnionType>(unwrap(t));
   return wrap(type.getContainedTypes()[pos]);
 }
 
@@ -134,7 +134,7 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) {
 }
 
 MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
-  return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
+  return wrap(cast<Torch::ListType>(unwrap(t)).getContainedType());
 }
 
 MlirTypeID torchMlirTorchListTypeGetTypeID() {
@@ -297,26 +297,26 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
 
 MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
   auto attrTensorType =
-      unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
+      cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
   return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
                                              attrTensorType.getShape(),
                                              attrTensorType.getElementType()));
 }
 
 int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
-  return unwrap(t).cast<Torch::NonValueTensorType>().getSizes().size();
+  return cast<Torch::NonValueTensorType>(unwrap(t)).getSizes().size();
 }
 
 bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
-  return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
+  return cast<Torch::NonValueTensorType>(unwrap(t)).hasSizes();
 }
 
 bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
-  return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
+  return cast<Torch::NonValueTensorType>(unwrap(t)).hasDtype();
 }
 
 int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
-  auto tensorType = unwrap(t).cast<Torch::NonValueTensorType>();
+  auto tensorType = cast<Torch::NonValueTensorType>(unwrap(t));
   bool hasSizes = tensorType.hasSizes();
   if (!hasSizes)
     return -1;
@@ -329,7 +329,7 @@ int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
 }
 
 MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
-  return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
+  return wrap(cast<Torch::NonValueTensorType>(unwrap(t)).getDtype());
 }
 
 MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
@@ -364,26 +364,26 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
 
 MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
   auto attrTensorType =
-      unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
+      cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
   return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
                                           attrTensorType.getShape(),
                                           attrTensorType.getElementType()));
 }
 
 int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
-  return unwrap(t).cast<Torch::ValueTensorType>().getSizes().size();
+  return cast<Torch::ValueTensorType>(unwrap(t)).getSizes().size();
 }
 
 bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
-  return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
+  return cast<Torch::ValueTensorType>(unwrap(t)).hasSizes();
 }
 
 bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
-  return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
+  return cast<Torch::ValueTensorType>(unwrap(t)).hasDtype();
 }
 
 int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
-  auto tensorType = unwrap(t).cast<Torch::ValueTensorType>();
+  auto tensorType = cast<Torch::ValueTensorType>(unwrap(t));
   bool hasSizes = tensorType.hasSizes();
   if (!hasSizes)
     return -1;
@@ -396,7 +396,7 @@ int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
 }
 
 MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
-  return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
+  return wrap(cast<Torch::ValueTensorType>(unwrap(t)).getDtype());
 }
 
 MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
@@ -487,12 +487,12 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
 }
 
 MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
-  auto type = unwrap(t).cast<Torch::DictType>();
+  auto type = cast<Torch::DictType>(unwrap(t));
   return wrap(type.getKeyType());
 }
 
 MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
-  auto type = unwrap(t).cast<Torch::DictType>();
+  auto type = cast<Torch::DictType>(unwrap(t));
   return wrap(type.getValueType());
 }
 
lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp (14 changes: 6 additions & 8 deletions)
@@ -63,7 +63,7 @@ LogicalResult windowFunctionImpl(OpBinder binder,
   // Create an f32 ValueTensorType with the same size as size, the
   // operand
   auto shapeOfOperand =
-      size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
+      dyn_cast<Torch::ValueTensorType>(size.getType()).getOptionalSizes();
   auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
       shapeOfOperand, rewriter.getF32Type());
   Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
@@ -897,8 +897,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         }
 
         if (DenseResourceElementsAttr attr =
-                binder.op->getAttr("torch.onnx.value")
-                    .dyn_cast_or_null<DenseResourceElementsAttr>()) {
+                dyn_cast_or_null<DenseResourceElementsAttr>(
+                    binder.op->getAttr("torch.onnx.value"))) {
          // Bytes are stored in little endian order. Big endian support will
          // require swizzling.
          if (!Endian::little) {
@@ -926,8 +926,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return success();
         }
 
-        if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value")
-                                    .dyn_cast_or_null<ElementsAttr>()) {
+        if (ElementsAttr attr = dyn_cast_or_null<ElementsAttr>(
+                binder.op->getAttr("torch.onnx.value"))) {
           rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
               binder.op, resultType, attr);
           return success();
@@ -2283,9 +2283,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
         Type listElemType =
-            tensors[0]
-                .getType()
-                .cast<Torch::BaseTensorType>()
+            cast<Torch::BaseTensorType>(tensors[0].getType())
                 .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                       /*optionalDtype=*/nullptr);
         Type listType = Torch::ListType::get(listElemType);
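The hunks above also exercise the null-tolerant variant: `Operation::getAttr` returns a null `Attribute` when the name is absent, and `dyn_cast_or_null<T>` accepts that null, so the free function is the drop-in replacement for the deprecated `.dyn_cast_or_null<T>()` member. A sketch of the idiom; the helper name is mine, not from the patch:

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Illustrative helper, not part of the commit: read an optional attribute.
// Returns a null ElementsAttr when "torch.onnx.value" is missing or has a
// different attribute kind; plain dyn_cast would assert on a null input.
ElementsAttr getOnnxValueAttr(Operation *op) {
  return dyn_cast_or_null<ElementsAttr>(op->getAttr("torch.onnx.value"));
}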
lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp (2 changes: 1 addition & 1 deletion)
@@ -176,7 +176,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
 
         auto conditionType =
-            conditionTensor.getType().cast<Torch::ValueTensorType>();
+            cast<Torch::ValueTensorType>(conditionTensor.getType());
         if (!conditionType || conditionType.getSizes().size() != 1)
           return rewriter.notifyMatchFailure(
               binder.op, "condition must have one single element per "
lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp (12 changes: 5 additions & 7 deletions)
@@ -1875,10 +1875,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               binder.op, "Axes should be the same size of starts and ends");
         }
 
-        auto stepsTy = steps.getType()
-                           .cast<Torch::ValueTensorType>()
-                           .toBuiltinTensor()
-                           .dyn_cast<RankedTensorType>();
+        auto stepsTy = dyn_cast<RankedTensorType>(
+            cast<Torch::ValueTensorType>(steps.getType()).toBuiltinTensor());
 
         if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
           return rewriter.notifyMatchFailure(
@@ -2804,7 +2802,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         Value modeStrValue;
 
         auto extract = [&rewriter, &binder](Value x, Value v) {
-          auto xTy = x.getType().cast<Torch::ValueTensorType>();
+          auto xTy = cast<Torch::ValueTensorType>(x.getType());
           Type extractTy = rewriter.getType<Torch::FloatType>();
           if (isa<IntegerType>(xTy.getDtype()))
             extractTy = rewriter.getType<Torch::IntType>();
@@ -2818,7 +2816,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           auto sizes =
               dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
           Torch::BaseTensorType operandType =
-              operand.getType().cast<Torch::BaseTensorType>();
+              cast<Torch::BaseTensorType>(operand.getType());
 
           SmallVector<int64_t> selectSizes;
           selectSizes.push_back(1);
@@ -2835,7 +2833,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             Value item = extract(operand, ext);
             itemList.push_back(item);
           }
-          auto xTy = operand.getType().cast<Torch::ValueTensorType>();
+          auto xTy = cast<Torch::ValueTensorType>(operand.getType());
           Value ValueList;
           if (isa<IntegerType>(xTy.getDtype())) {
             ValueList = rewriter.create<Torch::PrimListConstructOp>(
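One readability wrinkle is visible in the slice hunk of DefaultDomainQtoZ.cpp: member-syntax chains read left to right, while free-function casts nest inside out, so the cast applied last moves to the front of the expression. A sketch of that reordering; the helper name is mine, not from the patch:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch;

// Illustrative helper, not part of the commit. The member-style chain
//   v.getType().cast<Torch::ValueTensorType>().toBuiltinTensor()
//              .dyn_cast<RankedTensorType>()
// reads left to right; the free-function form nests instead, with the
// outermost (last-applied) dyn_cast written first.
RankedTensorType getRankedBuiltinType(Value v) {
  return dyn_cast<RankedTensorType>(
      cast<Torch::ValueTensorType>(v.getType()).toBuiltinTensor());
}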
