diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index f22be10c12b1..9c2bf27f15fe 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1816,4 +1816,56 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, input);
         return success();
       });
+
+  patterns.onOp(
+      "Hardmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // onnx.Hardmax can be expanded into the following Python code:
+        //
+        // import torch.nn.functional as F
+        // def hardmax(tensor, dim=-1):
+        //   maximums = torch.argmax(tensor, dim=dim, keepdim=False)
+        //   return F.one_hot(maximums)
+        //
+        // Given an example input:
+        //   tensor([[1, 2, 3],
+        //           [4, 6, 5],
+        //           [9, 8, 7]])
+        // The code above yields:
+        //   tensor([[0, 0, 1],
+        //           [0, 1, 0],
+        //           [1, 0, 0]])
+
+        Torch::ValueTensorType resultType;
+        int64_t axisValue;
+        Value input, axis;
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerAttr(axisValue, "axis") ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        auto loc = binder.getLoc();
+
+        std::optional<int64_t> axisIntTorch =
+            onnxDtypeIntToTorchDtypeInt(axisValue);
+        if (!axisIntTorch.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented support for the given axis conversion");
+        axis = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(axisIntTorch.value()));
+
+        // torch.argmax
+        Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getType<Torch::BoolType>(),
+            rewriter.getBoolAttr(false));
+        Value argmax = rewriter.create<Torch::AtenArgmaxOp>(
+            loc, resultType, input, axis, constKeepDims);
+
+        // one_hot
+        Value oneInt = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        rewriter.replaceOpWithNewOp<Torch::AtenOneHotOp>(binder.op, resultType,
+                                                         argmax, oneInt);
+
+        return success();
+      });
 }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 4214d3f222a1..2e975c4006aa 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1025,3 +1025,13 @@ func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<
   %0 = torch.operator "onnx.HardSwish"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_hardmax
+func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[ARGMAX:.+]] = torch.aten.argmax %arg0, %int6, %false : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: torch.aten.one_hot %[[ARGMAX]], %int1 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  %0 = torch.operator "onnx.Hardmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
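
Note (illustrative, not part of the patch): the argmax + one_hot decomposition described in the C++ comment can be sanity-checked directly in PyTorch. The `hardmax` helper below follows the comment's snippet; passing `num_classes` explicitly is an addition here so the one-hot width matches the size of the reduced dimension instead of being inferred from the largest index.

import torch
import torch.nn.functional as F

def hardmax(tensor, dim=-1):
    # Index of the maximum along `dim`; keepdim=False drops that dimension.
    maximums = torch.argmax(tensor, dim=dim, keepdim=False)
    # Scatter a 1 at each winning index. num_classes is set to the size of the
    # reduced dimension so the output width is stable for any input values.
    return F.one_hot(maximums, num_classes=tensor.shape[dim])

x = torch.tensor([[1, 2, 3],
                  [4, 6, 5],
                  [9, 8, 7]])
print(hardmax(x))
# tensor([[0, 0, 1],
#         [0, 1, 0],
#         [1, 0, 0]])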