Skip to content

Commit

Permalink
Update inference.json
Browse files Browse the repository at this point in the history
481-Rotation of output images in prostate_mri_anatomy

- The transformations applied to the test data during preprocessing are not reversed during postprocessing. Thus, the resulting predictions have a different rotation compared to the original images and labels. I added an 'invertd()' function that reverses all the traceable transformations in the preprocessing.
- Running the inference on a PC without a GPU leads to an error because the handlers function 'CheckpointLoader' attempts to load the model onto the GPU. I set the map_location argument to torch.device(device) so that it automatically loads the model onto the CPU when a GPU is not available.
- Additionally, I changed some of the config directories as they were referring to the validation dataset rather than the test dataset.

Signed-off-by: FaresAlMohamad <[email protected]>
  • Loading branch information
FaresAlMohamad authored Aug 5, 2023
1 parent a4a08f8 commit e327914
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions models/prostate_mri_anatomy/configs/inference.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
],
"bundle_root": "/workspace/data/prostate_mri_anatomy",
"output_dir": "$@bundle_root + '/eval'",
"dataset_dir": "/workspace/data/prostate158/prostate158_train/",
"datalist": "$list(@dataset_dir + pd.read_csv(@dataset_dir + 'valid.csv').t2)",
"dataset_dir": "/workspace/data/prostate158/prostate158_test/",
"datalist": "$list(@dataset_dir + pd.read_csv(@dataset_dir + 'test.csv').t2)",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"network_def": {
"_target_": "UNet",
Expand Down Expand Up @@ -114,6 +114,15 @@
2
]
},
{
"_target_": "Invertd",
"keys": "pred",
"transform": "@preprocessing",
"orig_keys": "image",
"meta_key_postfix": "meta_dict",
"nearest_interp": false,
"to_tensor": true
},
{
"_target_": "SaveImaged",
"keys": "pred",
Expand All @@ -127,6 +136,10 @@
{
"_target_": "CheckpointLoader",
"load_path": "$@bundle_root + '/models/model.pt'",
"map_location": {
"_target_": "torch.device",
"device": "@device"
},
"load_dict": {
"model": "@network"
}
Expand Down

0 comments on commit e327914

Please sign in to comment.