Skip to content

Commit

Permalink
[Torch] emit aten.ne.str and add folder (llvm#3242)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu authored and archana-ramalingam committed May 8, 2024
1 parent a763f77 commit 74b9709
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
41 changes: 40 additions & 1 deletion lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4715,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// PrimsConvertElementTypeOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
auto inputType = cast<BaseTensorType>(getA().getType());
auto outputType = cast<BaseTensorType>(getResult().getType());
if (inputType != outputType)
return nullptr;
if (!inputType.hasDtype() || !outputType.hasDtype())
return nullptr;
if (inputType.getDtype() != outputType.getDtype())
return nullptr;
return getA();
}

//===----------------------------------------------------------------------===//
// AtenMaxPool2dWithIndicesOp
//===----------------------------------------------------------------------===//

void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) {
if (!op.getResult1().use_empty()) {
return rewriter.notifyMatchFailure(
op, "result1 of MaxPool2dWithIndices should be unused");
}

Value result = rewriter.create<Torch::AtenMaxPool2dOp>(
op->getLoc(), op.getResult0().getType(), op.getSelf(),
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(),
op.getCeilMode());

op.getResult0().replaceAllUsesWith(result);
rewriter.eraseOp(op);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4967,4 +5006,4 @@ LogicalResult InitializeGlobalSlotsOp::verify() {
if (getInitialValues().size() != getSlotSymNames().size())
return emitOpError("expected number of operands to match number of slots");
return success();
}
}
41 changes: 41 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2974,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}

// -----

// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32>
func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> {
%int6 = torch.constant.int 6
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
return %0 : !torch.vtensor<[64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
// CHECK: return %[[RET]] : !torch.vtensor<[64],si32>
func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> {
%int6 = torch.constant.int 6
%0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32>
return %0 : !torch.vtensor<[64],si32>
}

// -----

// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize(
// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]]
// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32>
func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> {
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
return %result0 : !torch.vtensor<[10,64,56,56],f32>
}

0 comments on commit 74b9709

Please sign in to comment.