Skip to content

Commit

Permalink
support the pytorch's maxvit model by adding the aten::swapaxes opera…
Browse files Browse the repository at this point in the history
…tor support.

Co-authored-by: Masahiro Hiramori <[email protected]>

support the pytorch's maxvit model by adding the aten::swapaxes operator support.

Co-authored-by: Masahiro Hiramori <[email protected]>

support the pytorch's maxvit model by adding the aten::swapaxes operator support.

Co-authored-by: Masahiro Hiramori <[email protected]>
  • Loading branch information
nhat-14 and mshr-h committed Nov 8, 2023
1 parent de56d8c commit 65b3754
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4108,6 +4108,7 @@ def create_convert_map(self):
"aten::multinomial": self.multinomial,
"aten::_weight_norm": self.weight_norm,
"aten::copy_": self.inplace_copy,
"aten::swapaxes": self.transpose,
}

def update_convert_map(self, custom_map):
Expand Down
24 changes: 24 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5381,6 +5381,30 @@ def forward(self, x):
verify_model(PartialDimensionInplaceCopy(), [inputs])


@tvm.testing.uses_gpu
def test_swapaxes():
"""test_swapaxes"""
torch.set_grad_enabled(False)
input_shape = [2, 3, 10, 5]

class Swapaxes1(Module):
def forward(self, *args):
return args[0].swapaxes(2, 3)

class Swapaxes2(Module):
def forward(self, *args):
return args[0].swapaxes(-2, -1)

class Swapaxes3(Module):
def forward(self, *args):
return args[0].swapaxes(1, 1)

input_data = torch.rand(input_shape).float()
verify_model(Swapaxes1().float().eval(), input_data=input_data)
verify_model(Swapaxes2().float().eval(), input_data=input_data)
verify_model(Swapaxes3().float().eval(), input_data=input_data)


class TestSetSpan:
"""test structural equal between translated / hand-crafted relay IR with span tagged."""

Expand Down

0 comments on commit 65b3754

Please sign in to comment.