[NFC] Fix member cast change to global for landing collision (llvm#3407)
A PR landed while the migration away from the deprecated member cast function was in flight.
Updated the corresponding lines to use the free-function cast so the build passes.
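
For context, this is part of the MLIR-wide move from the deprecated member-style casts
(value.getType().cast<T>()) to the free functions cast<T>(...) / dyn_cast<T>(...). A minimal
sketch of the two forms, reusing the data value from the diff below; the dyn_cast variant is
shown only as a general illustration and is not part of this change:

    // Deprecated member-style form (what the colliding PR still used):
    //   auto dataTy = data.getType().cast<Torch::ValueTensorType>();
    // Free-function form that this commit switches the lines to:
    auto dataTy = cast<Torch::ValueTensorType>(data.getType());

    // When the type might not match, the checked free function yields a null
    // type instead of asserting (illustrative only; not needed in this pattern):
    if (auto tensorTy = dyn_cast<Torch::ValueTensorType>(data.getType())) {
      // ... use tensorTy ...
    }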
rsuderman authored May 31, 2024
1 parent 878ba72 commit 617b00b
Showing 1 changed file with 27 additions and 28 deletions.
55 changes: 27 additions & 28 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -470,40 +470,39 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           return success();
         });
   patterns.onOp(
-      "Scatter", 9,
-      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        int64_t axis;
-        if (binder.s64IntegerAttr(axis, "axis", {}))
-          return rewriter.notifyMatchFailure(binder.op, "axis bind failure");
+      "Scatter", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        int64_t axis;
+        if (binder.s64IntegerAttr(axis, "axis", {}))
+          return rewriter.notifyMatchFailure(binder.op, "axis bind failure");
 
-        Torch::ValueTensorType resultTy;
-        Value data, indices, updates;
-        if (binder.tensorOperandAtIndex(data, 0) ||
-            binder.tensorOperandAtIndex(indices, 1) ||
-            binder.tensorOperandAtIndex(updates, 2) ||
-            binder.tensorResultType(resultTy))
-          return failure();
+        Torch::ValueTensorType resultTy;
+        Value data, indices, updates;
+        if (binder.tensorOperandAtIndex(data, 0) ||
+            binder.tensorOperandAtIndex(indices, 1) ||
+            binder.tensorOperandAtIndex(updates, 2) ||
+            binder.tensorResultType(resultTy))
+          return failure();
 
-        auto dataTy = data.getType().cast<Torch::ValueTensorType>(),
-             indicesTy = indices.getType().cast<Torch::ValueTensorType>(),
-             updatesTy = updates.getType().cast<Torch::ValueTensorType>();
+        auto dataTy = cast<Torch::ValueTensorType>(data.getType()),
+             indicesTy = cast<Torch::ValueTensorType>(indices.getType()),
+             updatesTy = cast<Torch::ValueTensorType>(updates.getType());
 
-        int64_t dataRank = dataTy.getSizes().size(),
-                indicesRank = indicesTy.getSizes().size(),
-                updatesRank = updatesTy.getSizes().size();
+        int64_t dataRank = dataTy.getSizes().size(),
+                indicesRank = indicesTy.getSizes().size(),
+                updatesRank = updatesTy.getSizes().size();
 
-        if ((dataRank < 1) || (indicesRank < 1) || (updatesRank < 1) ||
-            (axis < -dataRank) || (axis >= dataRank))
-          return failure();
+        if ((dataRank < 1) || (indicesRank < 1) || (updatesRank < 1) ||
+            (axis < -dataRank) || (axis >= dataRank))
+          return failure();
 
-        Value axisValue = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getI64IntegerAttr(axis));
+        Value axisValue = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(axis));
 
-        rewriter.replaceOpWithNewOp<Torch::AtenScatterSrcOp>(
-            binder.op, resultTy, data, axisValue, indices, updates);
+        rewriter.replaceOpWithNewOp<Torch::AtenScatterSrcOp>(
+            binder.op, resultTy, data, axisValue, indices, updates);
 
-        return success();
-      });
+        return success();
+      });
   patterns.onOp(
       "ScatterElements", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
