diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 31a614cfa8b2..f0795d332f21 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -470,40 +470,39 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); patterns.onOp( - "Scatter", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - int64_t axis; - if (binder.s64IntegerAttr(axis, "axis", {})) - return rewriter.notifyMatchFailure(binder.op, "axis bind failure"); - - Torch::ValueTensorType resultTy; - Value data, indices, updates; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorOperandAtIndex(indices, 1) || - binder.tensorOperandAtIndex(updates, 2) || - binder.tensorResultType(resultTy)) - return failure(); + "Scatter", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + int64_t axis; + if (binder.s64IntegerAttr(axis, "axis", {})) + return rewriter.notifyMatchFailure(binder.op, "axis bind failure"); - auto dataTy = data.getType().cast(), - indicesTy = indices.getType().cast(), - updatesTy = updates.getType().cast(); + Torch::ValueTensorType resultTy; + Value data, indices, updates; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorOperandAtIndex(updates, 2) || + binder.tensorResultType(resultTy)) + return failure(); - int64_t dataRank = dataTy.getSizes().size(), - indicesRank = indicesTy.getSizes().size(), - updatesRank = updatesTy.getSizes().size(); + auto dataTy = cast(data.getType()), + indicesTy = cast(indices.getType()), + updatesTy = cast(updates.getType()); - if ((dataRank < 1) || (indicesRank < 1) || (updatesRank < 1) || - (axis < -dataRank) || (axis >= dataRank)) - return failure(); + int64_t dataRank = dataTy.getSizes().size(), + indicesRank = indicesTy.getSizes().size(), + updatesRank = updatesTy.getSizes().size(); - Value axisValue = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + if ((dataRank < 1) || (indicesRank < 1) || (updatesRank < 1) || + (axis < -dataRank) || (axis >= dataRank)) + return failure(); + + Value axisValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); - rewriter.replaceOpWithNewOp( - binder.op, resultTy, data, axisValue, indices, updates); + rewriter.replaceOpWithNewOp( + binder.op, resultTy, data, axisValue, indices, updates); - return success(); - }); + return success(); + }); patterns.onOp( "ScatterElements", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {