Skip to content

Commit

Permalink
[linalg] Fix torch.aten.add of torch.bool
Browse files Browse the repository at this point in the history
Addition of bools saturate which equates to an `or` operator. Updated to
avoid some noticed downstream failures.
  • Loading branch information
rsuderman committed Nov 1, 2024
1 parent 9c1e3b8 commit 7b9081d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (isa<mlir::FloatType>(dtype)) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::AddFOp>(loc, lhs, scaled);
} else if (dtype.isInteger(1)) {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::OrIOp>(loc, lhs, scaled);
} else {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::AddIOp>(loc, lhs, scaled);
Expand Down
29 changes: 29 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,35 @@ def ElementwiseAddModule_basic(module, tu: TestUtils):
# ==============================================================================


# Addition is an interesting special case of a binary op, because under the hood
# it carries a third scalar "alpha" parameter, which needs special handling.
class ElementwiseAddBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([4], torch.bool, True),
([4], torch.bool, True),
]
)
def forward(self, a, b):
return a + b


@register_test_case(module_factory=lambda: ElementwiseAddBoolModule())
def ElementwiseAddBoolModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([False, False, True, True]),
torch.tensor([False, True, True, False]),
)


# ==============================================================================


class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 7b9081d

Please sign in to comment.