Skip to content

Commit

Permalink
update resnet tests and deployment files
Browse files Browse the repository at this point in the history
Signed-off-by: vgrau98 <[email protected]>
  • Loading branch information
vgrau98 committed Oct 9, 2023
1 parent 5ac3627 commit 3735733
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ opencv-python-headless
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr
huggingface_hub
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ all =
zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
nibabel =
nibabel
ninja =
Expand Down
82 changes: 43 additions & 39 deletions tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import unittest
from typing import TYPE_CHECKING
import os
import sys
import re
import copy

Expand All @@ -33,6 +34,8 @@
has_torchvision = True
else:
torchvision, has_torchvision = optional_import("torchvision")

has_hf_modules = ("huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules)

# from torchvision.models import ResNet50_Weights, resnet50

Expand Down Expand Up @@ -202,46 +205,47 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape
pretrained_net = model(**cp_input_param)
assert (equal_state_dict(net.state_dict(), pretrained_net.state_dict()))

# True flag
cp_input_param["pretrained"] = True
resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1))

# Duplicate. see monai/networks/nets/resnet.py
def get_medicalnet_pretrained_resnet_args(resnet_depth: int) :
"""
Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth
"""
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type

bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)

# With orig. test cases
if (input_param.get("spatial_dims", 3) == 3 and
input_param.get("n_input_channels", 3)==1 and
input_param.get("feed_forward", True) is False and
input_param.get("shortcut_type", "B") == shortcut_type and
(input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True)
):
model(**cp_input_param)
else:
with self.assertRaises(NotImplementedError):
if has_hf_modules:
# True flag
cp_input_param["pretrained"] = True
resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1))

# Duplicate. see monai/networks/nets/resnet.py
def get_medicalnet_pretrained_resnet_args(resnet_depth: int) :
"""
Return correct shortcut_type and bias_downsample for pretrained MedicalNet weights according to rensnet depth
"""
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type

bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)

# With orig. test cases
if (input_param.get("spatial_dims", 3) == 3 and
input_param.get("n_input_channels", 3)==1 and
input_param.get("feed_forward", True) is False and
input_param.get("shortcut_type", "B") == shortcut_type and
(input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True)
):
model(**cp_input_param)

# forcing MedicalNet pretrained download for 3D tests cases
cp_input_param["n_input_channels"] = 1
cp_input_param["feed_forward"] = False
cp_input_param["shortcut_type"] = shortcut_type
cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample!=-1 else True
if cp_input_param.get("spatial_dims", 3)==3:
pretrained_net = model(**cp_input_param).to(device)
medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device = device)
medicalnet_state_dict = {key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()}
assert(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict))
else:
with self.assertRaises(NotImplementedError):
model(**cp_input_param)

# forcing MedicalNet pretrained download for 3D tests cases
cp_input_param["n_input_channels"] = 1
cp_input_param["feed_forward"] = False
cp_input_param["shortcut_type"] = shortcut_type
cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample!=-1 else True
if cp_input_param.get("spatial_dims", 3)==3:
pretrained_net = model(**cp_input_param).to(device)
medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device = device)
medicalnet_state_dict = {key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()}
assert(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict))

# clean
os.remove(tmp_ckpt_filename)
Expand Down

0 comments on commit 3735733

Please sign in to comment.