From 985e7796a4e4c2b939c4c350047db2473fcdc8f2 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 5 Jan 2024 15:16:49 -0800 Subject: [PATCH] [linalg] Added `aten.clamp` support with integers to `torch-to-linalg` (#2718) The lowering for `aten.clamp` did not support integer types. Added support for integer types including a signed integer test. --- .../TorchToLinalg/Uncategorized.cpp | 55 +++++++++++++------ .../test_suite/elementwise.py | 28 ++++++++++ 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0943534dbd9c..f742ded3f1bd 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { - Type dtype = converter->convertType(clamp.getType()) - .cast() - .getElementType(); - if (!dtype.isa()) { - clamp.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); @@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp( clamp.emitError("unimplemented: runtime optional type"); return nullptr; } - auto result = payloadArgs[0]; - if (!min.getType().isa()) { - auto minPromoted = convertScalarToDtype(b, loc, min, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::ULT, - result, minPromoted); - result = b.create(loc, pred, minPromoted, result); + + Type dtype = converter->convertType(clamp.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + clamp.emitError("unimplement type for clamp"); + return nullptr; } - if (!max.getType().isa()) { - auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::UGT, - result, maxPromoted); - result = b.create(loc, pred, maxPromoted, result); + + Type dstOriginalDtype = clamp.getType().cast().getDtype(); + bool isUnsigned = isa(dstOriginalDtype); + if (auto intTy = dstOriginalDtype.dyn_cast()) { + isUnsigned = intTy.isUnsigned(); } + auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value { + clamp = convertScalarToDtype(b, loc, clamp, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/dstOriginalDtype); + + Value pred; + if (dtype.isa()) { + auto cmp = + getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; + pred = b.create(loc, cmp, input, clamp); + } else if (dtype.isa()) { + auto cmp = + isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; + if (getMax) + cmp = arith::invertPredicate(cmp); + pred = b.create(loc, cmp, input, clamp); + } + return b.create(loc, pred, clamp, input); + }; + + auto result = payloadArgs[0]; + if (!min.getType().isa()) + result = cmpSelect(result, min, /*getMax=*/false); + if (!max.getType().isa()) + result = cmpSelect(result, max, /*getMax=*/true); return result; } if (auto clampTensor = dyn_cast(op)) { diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 2b86aed35e52..c18c9103d888 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -988,6 +988,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseClampTensorInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True) + ]) + def forward(self, x): + min = -5 + max = 5 + min_clamp = torch.clamp(x, min) + max_clamp = torch.clamp(x, max=max) + both_clamp = torch.clamp(x, min=min, max=max) + return min_clamp, max_clamp, both_clamp + + +@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module()) +def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8)) + + +# ============================================================================== + + + class ElementwiseClampMinTensorFloatModule(torch.nn.Module): def __init__(self):