Skip to content

Commit

Permalink
[linalg] Added aten.clamp support with integers to torch-to-linalg (#2718)
Browse files Browse the repository at this point in the history

The lowering for `aten.clamp` did not support integer types. Added
support for integer types including a signed integer test.
  • Loading branch information
rsuderman authored Jan 5, 2024
1 parent 6096fcb commit 985e779
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 18 deletions.
55 changes: 37 additions & 18 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
}
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
Type dtype = converter->convertType(clamp.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
clamp.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
AtenClampOp::Adaptor adaptor(operands);
auto min = adaptor.getMin();
auto max = adaptor.getMax();
Expand All @@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
clamp.emitError("unimplemented: runtime optional type");
return nullptr;
}
auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>()) {
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
result, minPromoted);
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);

Type dtype = converter->convertType(clamp.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType, mlir::IntegerType>()) {
clamp.emitError("unimplement type for clamp");
return nullptr;
}
if (!max.getType().isa<Torch::NoneType>()) {
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
result, maxPromoted);
result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);

Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
if (auto intTy = dstOriginalDtype.dyn_cast<IntegerType>()) {
isUnsigned = intTy.isUnsigned();
}
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
clamp = convertScalarToDtype(b, loc, clamp, dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/dstOriginalDtype);

Value pred;
if (dtype.isa<mlir::FloatType>()) {
auto cmp =
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
} else if (dtype.isa<mlir::IntegerType>()) {
auto cmp =
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
if (getMax)
cmp = arith::invertPredicate(cmp);
pred = b.create<arith::CmpIOp>(loc, cmp, input, clamp);
}
return b.create<arith::SelectOp>(loc, pred, clamp, input);
};

auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>())
result = cmpSelect(result, min, /*getMax=*/false);
if (!max.getType().isa<Torch::NoneType>())
result = cmpSelect(result, max, /*getMax=*/true);
return result;
}
if (auto clampTensor = dyn_cast<AtenClampTensorOp>(op)) {
Expand Down
28 changes: 28 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseClampTensorInt8Module(torch.nn.Module):
    """E2E test module exercising torch.clamp on a signed int8 tensor.

    Covers the three calling conventions: min-only (positional),
    max-only (keyword), and both bounds (keywords).
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int8, True)
    ])
    def forward(self, x):
        # Use names that do not shadow the built-in min/max.
        lower = -5
        upper = 5
        clamped_low = torch.clamp(x, lower)
        clamped_high = torch.clamp(x, max=upper)
        clamped_both = torch.clamp(x, min=lower, max=upper)
        return clamped_low, clamped_high, clamped_both


@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module())
def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils):
    # Random int8 input spanning [-10, 10) so values fall on both sides
    # of the clamp bounds (-5, 5) used inside the module's forward().
    module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8))


# ==============================================================================



class ElementwiseClampMinTensorFloatModule(torch.nn.Module):

def __init__(self):
Expand Down

0 comments on commit 985e779

Please sign in to comment.