diff --git a/models/vista3d/configs/inference.json b/models/vista3d/configs/inference.json index 92f2c0dd..2990da83 100644 --- a/models/vista3d/configs/inference.json +++ b/models/vista3d/configs/inference.json @@ -15,7 +15,7 @@ "output_dtype": "$np.float32", "output_postfix": "trans", "separate_folder": true, - "input_dict": "${'image': '/data/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label_prompt': [3]}", + "input_dict": "${'image': '/workspace/Task03_Liver/imagesTr_decompress/liver_100.nii', 'label_prompt': [1]}", "everything_labels": "$list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))", "metadata_path": "$@bundle_root + '/configs/metadata.json'", "metadata": "$json.loads(pathlib.Path(@metadata_path).read_text())", @@ -54,6 +54,7 @@ { "_target_": "LoadImaged", "keys": "@image_key", + "reader": "NibabelGPUReader", "image_only": true }, { diff --git a/models/vista3d/scripts/evaluator.py b/models/vista3d/scripts/evaluator.py index f20261b3..3b75b391 100644 --- a/models/vista3d/scripts/evaluator.py +++ b/models/vista3d/scripts/evaluator.py @@ -22,6 +22,7 @@ from monai.utils import ForwardMode, IgniteInfo, RankFilter, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from torch.utils.data import DataLoader +from .warmup import warm_up rearrange, _ = optional_import("einops", name="rearrange") @@ -133,6 +134,7 @@ def __init__( self.inferer = SimpleInferer() if inferer is None else inferer self.hyper_kwargs = hyper_kwargs self.logger.addFilter(RankFilter()) + warm_up() def transform_points(self, point, affine): """transform point to the coordinates of the transformed image diff --git a/models/vista3d/scripts/warmup.py b/models/vista3d/scripts/warmup.py new file mode 100644 index 00000000..c0bbbdcb --- /dev/null +++ b/models/vista3d/scripts/warmup.py @@ -0,0 +1,16 @@ +import tempfile +import cupy as cp +import kvikio + +def warm_up(): + a = cp.arange(100) + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file_name = tmp_file.name + f = kvikio.CuFile(tmp_file_name, "w") + # Write whole array to file + f.write(a) + f.close() + + b = cp.empty_like(a) + f = kvikio.CuFile(tmp_file_name, "r") + f.read(b)