diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index ead52f6a0b..ea657acac1 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -11,7 +11,7 @@ from mmengine.dist import all_gather, broadcast from mmengine.model import (BaseDataPreprocessor, BaseModel, ExponentialMovingAverage, - MMDistributedDataParallel, MMPipelineParallel, + MMDistributedDataParallel, MMSeparateDistributedDataParallel) from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict from mmengine.testing import assert_allclose @@ -48,24 +48,6 @@ def forward(self, inputs, data_sample=None, mode='tensor'): return x -class ToyLinearModel(BaseModel): - # because the flop analyzer cannot analyze the conv layer - def __init__(self): - super().__init__(data_preprocessor=ToyDataPreprocessor()) - self.linear1 = nn.Linear(1, 1) - self.linear2 = nn.Linear(1, 2) - - def forward(self, inputs, data_sample=None, mode='tensor'): - x = self.linear1(inputs) - x = self.linear2(x) - if mode == 'loss': - return dict(loss=x) - elif mode == 'predict': - return x - else: - return x - - class ComplexModel(BaseModel): def __init__(self): @@ -316,36 +298,3 @@ def test_test_step(self): data = dict(inputs=inputs, data_sample=MagicMock()) predictions = fsdp_model.test_step(data) self.assertIsInstance(predictions, torch.Tensor) - - -@unittest.skipIf( - torch.cuda.device_count() < 2, - reason='need 2 or more gpu to test pipeline parallel') -@unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('2.0.0'), - reason='pipeline parallelism needs Pytorch 2.0.0 or higher') -class TestMMPipelineParallel(unittest.TestCase): - - def test_init(self): - model = ToyLinearModel() - model = MMPipelineParallel(model, num_pipelines=2) - self.assertEqual(model.model.linear1.weight.device, - torch.device('meta')) - self.assertEqual(model.model.linear2.weight.device, - torch.device('meta')) - - def test_forward(self): - model = ToyLinearModel() - model = MMPipelineParallel(model, num_pipelines=2) - inputs = torch.randn(6, 1) - self.assertIsInstance(model(inputs), torch.Tensor) - self.assertEqual(model.model.linear1.weight.device, - torch.device('cuda:0')) - self.assertEqual(model.model.linear2.weight.device, - torch.device('cuda:1')) - self.assertEqual(model(inputs).shape, torch.Size([6, 2])) - self.assertEqual(model(inputs).device, torch.device('cuda:1')) - self.assertIn('linear1', model.device_map) - self.assertIn('linear1', model.device_map) - self.assertEqual(model.device_map['linear1']['part_id'], 0) - self.assertEqual(model.device_map['linear2']['part_id'], 1)