diff --git a/models/vista3d/configs/inference.json b/models/vista3d/configs/inference.json index 92f2c0dd..5193ec14 100644 --- a/models/vista3d/configs/inference.json +++ b/models/vista3d/configs/inference.json @@ -15,7 +15,8 @@ "output_dtype": "$np.float32", "output_postfix": "trans", "separate_folder": true, - "input_dict": "${'image': '/data/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label_prompt': [3]}", + "sw_batch_size": 10, + "input_dict": "${'image': '/data/Task09_Spleen/imagesTr/spleen_10.nii.gz'}", "everything_labels": "$list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))", "metadata_path": "$@bundle_root + '/configs/metadata.json'", "metadata": "$json.loads(pathlib.Path(@metadata_path).read_text())", @@ -40,7 +41,6 @@ 1.5, 1.5 ], - "sw_batch_size": 1, "patch_size": [ 128, 128, @@ -105,9 +105,19 @@ "dtype": "$torch.float32" } ], + "range_preprocessing": { + "_target_": "Range", + "name": "preprocessing", + "recursive": true + }, + "range_postprocessing": { + "_target_": "Range", + "name": "postprocessing", + "recursive": true + }, "preprocessing": { "_target_": "Compose", - "transforms": "$@preprocessing_transforms " + "transforms": "$@range_preprocessing(@preprocessing_transforms)" }, "dataset": { "_target_": "Dataset", @@ -128,48 +138,57 @@ "sw_batch_size": "@sw_batch_size", "use_point_window": "@use_point_window" }, + "postprocessing_transforms": [ + { + "_target_": "ToDeviced", + "keys": "pred", + "device": "cpu", + "_disabled_": true + }, + { + "_target_": "monai.apps.vista3d.transforms.VistaPostTransformd", + "keys": "pred" + }, + { + "_target_": "Invertd", + "keys": "pred", + "transform": "$copy.deepcopy(@preprocessing)", + "orig_keys": "@image_key", + "nearest_interp": true, + "to_tensor": true + }, + { + "_target_": "Lambdad", + "func": "$lambda x: torch.nan_to_num(x, nan=255)", + "keys": "pred" + }, + { + "_target_": "SaveImaged", + "keys": "pred", + "resample": false, + 
"output_dir": "@output_dir", + "output_ext": "@output_ext", + "output_dtype": "@output_dtype", + "output_postfix": "@output_postfix", + "separate_folder": "@separate_folder" + } + ], "postprocessing": { "_target_": "Compose", - "transforms": [ - { - "_target_": "ToDeviced", - "keys": "pred", - "device": "cpu", - "_disabled_": true - }, - { - "_target_": "monai.apps.vista3d.transforms.VistaPostTransformd", - "keys": "pred" - }, - { - "_target_": "Invertd", - "keys": "pred", - "transform": "$copy.deepcopy(@preprocessing)", - "orig_keys": "@image_key", - "nearest_interp": true, - "to_tensor": true - }, - { - "_target_": "Lambdad", - "func": "$lambda x: torch.nan_to_num(x, nan=255)", - "keys": "pred" - }, - { - "_target_": "SaveImaged", - "keys": "pred", - "resample": false, - "output_dir": "@output_dir", - "output_ext": "@output_ext", - "output_dtype": "@output_dtype", - "output_postfix": "@output_postfix", - "separate_folder": "@separate_folder" - } - ] + "transforms": "$@range_postprocessing(@postprocessing_transforms)" }, "handlers": [ { "_target_": "StatsHandler", "iteration_log": false + }, + { + "_target_": "RangeHandler", + "events": "ITERATION" + }, + { + "_target_": "RangeHandler", + "events": "BATCH" } ], "checkpointloader": { diff --git a/models/vista3d/scripts/inferer.py b/models/vista3d/scripts/inferer.py index 345b1a89..31be6dda 100644 --- a/models/vista3d/scripts/inferer.py +++ b/models/vista3d/scripts/inferer.py @@ -15,6 +15,7 @@ import torch from monai.apps.vista3d.inferer import point_based_window_inferer from monai.inferers import Inferer, SlidingWindowInfererAdapt +from monai.utils import Range from torch import Tensor @@ -80,38 +81,40 @@ def __call__( device = inputs[0].device else: device = inputs.device - val_outputs = point_based_window_inferer( - inputs=inputs, - roi_size=self.roi_size, - sw_batch_size=self.sw_batch_size, - transpose=True, - with_coord=True, - predictor=network, - mode="gaussian", - sw_device=device, - device=device, - 
overlap=self.overlap, - point_coords=point_coords, - point_labels=point_labels, - class_vector=class_vector, - prompt_class=prompt_class, - prev_mask=prev_mask, - labels=labels, - label_set=label_set, - ) + with Range("SW_PointInfer"): + val_outputs = point_based_window_inferer( + inputs=inputs, + roi_size=self.roi_size, + sw_batch_size=self.sw_batch_size, + transpose=True, + with_coord=True, + predictor=network, + mode="gaussian", + sw_device=device, + device=device, + overlap=self.overlap, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + prev_mask=prev_mask, + labels=labels, + label_set=label_set, + ) else: - val_outputs = SlidingWindowInfererAdapt( - roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, with_coord=True - )( - inputs, - network, - transpose=True, - point_coords=point_coords, - point_labels=point_labels, - class_vector=class_vector, - prompt_class=prompt_class, - prev_mask=prev_mask, - labels=labels, - label_set=label_set, - ) + with Range("SW_Infer"): + val_outputs = SlidingWindowInfererAdapt( + roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, with_coord=True + )( + inputs, + network, + transpose=True, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + prev_mask=prev_mask, + labels=labels, + label_set=label_set, + ) return val_outputs