diff --git a/models/prostate_mri_anatomy/configs/inference.json b/models/prostate_mri_anatomy/configs/inference.json index efdc02c5..92b7f5ee 100644 --- a/models/prostate_mri_anatomy/configs/inference.json +++ b/models/prostate_mri_anatomy/configs/inference.json @@ -5,8 +5,8 @@ ], "bundle_root": "/workspace/data/prostate_mri_anatomy", "output_dir": "$@bundle_root + '/eval'", - "dataset_dir": "/workspace/data/prostate158/prostate158_train/", - "datalist": "$list(@dataset_dir + pd.read_csv(@dataset_dir + 'valid.csv').t2)", + "dataset_dir": "/workspace/data/prostate158/prostate158_test/", + "datalist": "$list(@dataset_dir + pd.read_csv(@dataset_dir + 'test.csv').t2)", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "network_def": { "_target_": "UNet", @@ -114,6 +114,15 @@ 2 ] }, + { + "_target_": "Invertd", + "keys": "pred", + "transform": "@preprocessing", + "orig_keys": "image", + "meta_key_postfix": "meta_dict", + "nearest_interp": false, + "to_tensor": true + }, { "_target_": "SaveImaged", "keys": "pred", @@ -127,6 +136,10 @@ { "_target_": "CheckpointLoader", "load_path": "$@bundle_root + '/models/model.pt'", + "map_location": { + "_target_": "torch.device", + "device": "@device" + }, "load_dict": { "model": "@network" }