Skip to content

Commit

Permalink
[Relay][Pytorch] Add support for aten::bitwise_and (#16105)
Browse files Browse the repository at this point in the history
add support for aten::bitwise_and
  • Loading branch information
mshr-h authored Nov 10, 2023
1 parent f9ac3b9 commit 95d769e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,14 @@ def bitwise_xor(self, inputs, input_types):

return _op.bitwise_xor(lhs, rhs)

def bitwise_and(self, inputs, input_types):
lhs = inputs[0]
rhs = inputs[1]
lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")

return _op.bitwise_and(lhs, rhs)

def logical_not(self, inputs, input_types):
data = _wrap_const(inputs[0])
return _op.logical_not(_op.cast(data, "bool"))
Expand Down Expand Up @@ -4033,6 +4041,7 @@ def create_convert_map(self):
"aten::logical_xor": self.logical_xor,
"aten::bitwise_not": self.bitwise_not,
"aten::bitwise_xor": self.bitwise_xor,
"aten::bitwise_and": self.bitwise_and,
"aten::Bool": self.Bool,
"aten::Float": self.Float,
"aten::rsub": self.rsub,
Expand Down
27 changes: 27 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3695,6 +3695,33 @@ def forward(self, *args):
verify_model(BitwiseXor2().float().eval(), input_data=[lhs])


def test_forward_bitwise_and():
"""test_forward_bitwise_and"""
torch.set_grad_enabled(False)

class BitwiseAnd1(Module):
def forward(self, *args):
return torch.bitwise_and(args[0], args[1])

class BitwiseAnd2(Module):
def forward(self, *args):
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
if torch.cuda.is_available():
rhs = rhs.cuda()
return torch.bitwise_and(args[0], rhs)

lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])

lhs = torch.tensor([True, True, False])
rhs = torch.tensor([False, True, False])
verify_model(BitwiseAnd1().float().eval(), input_data=[lhs, rhs])

lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
verify_model(BitwiseAnd2().float().eval(), input_data=[lhs])


@tvm.testing.uses_gpu
def test_forward_logical_xor():
"""test_forward_logical_xor"""
Expand Down

0 comments on commit 95d769e

Please sign in to comment.