Skip to content

Commit

Permalink
Fix errors with earlier torchvision versions
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Graham <[email protected]>
  • Loading branch information
marksgraham committed Aug 2, 2023
1 parent c118d10 commit a4e16c8
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from parameterized import parameterized

from monai.losses import PerceptualLoss
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

_, has_torchvision = optional_import("torchvision")
TEST_CASES = [
[{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)],
[
Expand Down Expand Up @@ -45,6 +48,8 @@
]


@SkipIfBeforePyTorchVersion((1, 10))
@unittest.skipUnless(has_torchvision, "Requires torchvision")
class TestPerceptualLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape, target_shape):
Expand Down Expand Up @@ -79,4 +84,4 @@ def test_medicalnet_on_2d_data(self):


if __name__ == "__main__":
unittest.main()
unittest.main()

0 comments on commit a4e16c8

Please sign in to comment.