diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py index 89794f3414..a80c7f35cb 100644 --- a/tests/test_runner/test_amp.py +++ b/tests/test_runner/test_amp.py @@ -5,7 +5,7 @@ import torch.nn as nn import mmengine -from mmengine.device import get_device, is_mlu_available +from mmengine.device import get_device, is_mlu_available, is_npu_available from mmengine.runner import autocast from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION @@ -14,7 +14,22 @@ class TestAmp(unittest.TestCase): def test_autocast(self): - if is_mlu_available(): + if is_npu_available(): + device = 'npu' + with autocast(device_type=device): + # torch.autocast support npu mode. + layer = nn.Conv2d(1, 1, 1).to(device) + res = layer(torch.randn(1, 1, 1, 1).to(device)) + self.assertIn(res.dtype, (torch.bfloat16, torch.float16)) + with autocast(enabled=False, device_type=device): + res = layer(torch.randn(1, 1, 1, 1).to(device)) + self.assertEqual(res.dtype, torch.float32) + # Test with fp32_enabled + with autocast(enabled=False, device_type=device): + layer = nn.Conv2d(1, 1, 1).to(device) + res = layer(torch.randn(1, 1, 1, 1).to(device)) + self.assertEqual(res.dtype, torch.float32) + elif is_mlu_available(): device = 'mlu' with autocast(device_type=device): # torch.autocast support mlu mode.