diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 81392a08ecd13..402ab592027ca 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4108,6 +4108,7 @@ def create_convert_map(self): "aten::multinomial": self.multinomial, "aten::_weight_norm": self.weight_norm, "aten::copy_": self.inplace_copy, + "aten::swapaxes": self.transpose, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index abdbda8e40052..b9c1b6ce9cd10 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -5381,6 +5381,30 @@ def forward(self, x): verify_model(PartialDimensionInplaceCopy(), [inputs]) +@tvm.testing.uses_gpu +def test_swapaxes(): + """test_swapaxes""" + torch.set_grad_enabled(False) + input_shape = [2, 3, 10, 5] + + class Swapaxes1(Module): + def forward(self, *args): + return args[0].swapaxes(2, 3) + + class Swapaxes2(Module): + def forward(self, *args): + return args[0].swapaxes(-2, -1) + + class Swapaxes3(Module): + def forward(self, *args): + return args[0].swapaxes(1, 1) + + input_data = torch.rand(input_shape).float() + verify_model(Swapaxes1().float().eval(), input_data=input_data) + verify_model(Swapaxes2().float().eval(), input_data=input_data) + verify_model(Swapaxes3().float().eval(), input_data=input_data) + + class TestSetSpan: """test structural equal between translated / hand-crafted relay IR with span tagged."""