diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index d6f95d2db..c112dcda5 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -1877,6 +1877,11 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
         gm_script_output = gm_script(*inputs[0])
         assert_close(sharded_output, gm_script_output)
 
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
     def test_sharded_quant_fp_ebc_tw_meta(self) -> None:
         # Simulate inference, take unsharded cpu model and shard on meta
         # Use PositionWeightedModuleCollection, FP used in production
diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index ee5aed804..dc3ce705b 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -311,6 +311,11 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
         # TODO: Fix Unflatten
         # torch.export.unflatten(ep)
 
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
     def test_sharded_quant_fpebc_non_strict_export(self) -> None:
         sharded_model, input_kjts = _sharded_quant_ebc_model(
             local_device="cpu", compute_device="cpu", feature_processor=True
diff --git a/torchrec/modules/tests/test_feature_processor_.py b/torchrec/modules/tests/test_feature_processor_.py
index 3845f4395..ea2c1bcb8 100644
--- a/torchrec/modules/tests/test_feature_processor_.py
+++ b/torchrec/modules/tests/test_feature_processor_.py
@@ -61,6 +61,7 @@ def test_populate_weights(self) -> None:
             weighted_features.lengths(), weighted_features_gm_script.lengths()
         )
+    # TODO: this test is not being run
     # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 0,
         "Not enough GPUs available",
@@ -132,6 +133,7 @@ def test_populate_weights(self) -> None:
             empty_fp_kjt.length_per_key(), empty_fp_kjt_gm_script.length_per_key()
         )
+    # TODO: this test is not being run
     # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 0,
         "Not enough GPUs available",
@@ -151,6 +153,11 @@ def test_rematerialize_from_meta(self) -> None:
             self.assertTrue(pwmc.position_weights_dict[key] is param)
             torch.testing.assert_close(param, torch.ones_like(param))
 
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs available",
+    )
     def test_copy(self) -> None:
         pwmc = PositionWeightedModuleCollection(
             max_feature_lengths={"f1": 10, "f2": 10},
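
For context, below is a minimal, standalone sketch of the guard pattern this diff applies (the test class and method names are hypothetical, not part of the PR). The skip condition is evaluated when the decorator is applied at import time, so on hosts without enough visible CUDA devices the test is reported as skipped rather than failed.

# Minimal sketch of the skip guard used in the diff; the class and test
# names below are hypothetical and only illustrate the pattern.
import unittest

import torch


class MultiGpuGuardExample(unittest.TestCase):
    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs available",
    )
    def test_needs_two_gpus(self) -> None:
        # Runs only when at least two CUDA devices are visible.
        self.assertGreaterEqual(torch.cuda.device_count(), 2)


if __name__ == "__main__":
    unittest.main()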