diff --git a/vista3d/NVIDIA OneWay Noncommercial License.txt b/vista3d/NVIDIA OneWay Noncommercial License.txt index 048a4a8..58d9eca 100644 --- a/vista3d/NVIDIA OneWay Noncommercial License.txt +++ b/vista3d/NVIDIA OneWay Noncommercial License.txt @@ -27,8 +27,8 @@ Works are “made available” under this license by including in or with the Wo 4. Disclaimer of Warranty. -THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 5. Limitation of Liability. diff --git a/vista3d/README.md b/vista3d/README.md index 801b062..ebcbbce 100644 --- a/vista3d/README.md +++ b/vista3d/README.md @@ -77,8 +77,8 @@ pip install -r requirements.txt Download the [model checkpoint](https://drive.google.com/file/d/1eLIxQwnxGsjggxiVjdcAyNvJ5DYtqmdc/view?usp=sharing) and save it at ./models/model.pt. ### Inference -The [NIM Demo (VISTA3D NVIDIA Inference Microservices)](https://build.nvidia.com/nvidia/vista-3d) does not support medical data upload due to legal concerns. -We provide scripts for inference locally. The automatic segmentation label definition can be found at [label_dict](./data/jsons/label_dict.json). +The [NIM Demo (VISTA3D NVIDIA Inference Microservices)](https://build.nvidia.com/nvidia/vista-3d) does not support medical data upload due to legal concerns. +We provide scripts for inference locally. The automatic segmentation label definition can be found at [label_dict](./data/jsons/label_dict.json). 1. We provide the `infer.py` script and its light-weight front-end `debugger.py`. User can directly lauch a local interface for both automatic and interactive segmentation. ``` python -m scripts.debugger run @@ -154,7 +154,7 @@ We provide scripts to run SAM2 evaluation. Modify SAM2 source code to support ba async_loading_frames=async_loading_frames, ) if z_slice is not None: - images = images[z_slice] + images = images[z_slice] ``` Run evaluation ``` diff --git a/vista3d/requirements.txt b/vista3d/requirements.txt index 46efb3f..c978d66 100644 --- a/vista3d/requirements.txt +++ b/vista3d/requirements.txt @@ -14,4 +14,4 @@ einops==0.6.1 ml-collections timm pytorch-ignite -tensorboardX \ No newline at end of file +tensorboardX diff --git a/vista3d/scripts/validation/val_multigpu_sam2_point_iterative.py b/vista3d/scripts/validation/val_multigpu_sam2_point_iterative.py index a44d0d8..f47699a 100644 --- a/vista3d/scripts/validation/val_multigpu_sam2_point_iterative.py +++ b/vista3d/scripts/validation/val_multigpu_sam2_point_iterative.py @@ -16,15 +16,16 @@ import json import logging import os -import sys import random +import sys from datetime import timedelta from typing import Optional, Sequence, Union -from matplotlib import pyplot as plt + import monai -from PIL import Image +import numpy as np import torch import torch.distributed as dist +from matplotlib import pyplot as plt from monai import transforms from monai.apps.auto3dseg.auto_runner import logger from monai.auto3dseg.utils import datafold_read @@ -33,13 +34,13 @@ from monai.data import DataLoader, partition_dataset from monai.metrics import compute_dice from monai.utils import set_determinism -import numpy as np -import pdb +from PIL import Image from sam2.build_sam import build_sam2_video_predictor from scipy.ndimage import binary_erosion from ..train import CONFIG + def save_nifti_frames_to_jpg(data, output_folder=None): data = torch.squeeze(data) # Ensure output folder exists @@ -47,7 +48,7 @@ def save_nifti_frames_to_jpg(data, output_folder=None): os.makedirs(output_folder, exist_ok=True) # Loop through each frame in the 3D image for i in range(data.shape[2]): - save_name = os.path.join(output_folder, f'{i + 1:04d}.jpg') + save_name = os.path.join(output_folder, f"{i + 1:04d}.jpg") if os.path.exists(save_name): continue frame = data[:, :, i] @@ -59,26 +60,26 @@ def save_nifti_frames_to_jpg(data, output_folder=None): return output_folder + def plot(pred, label, point, point_label, name): - fig, (ax1, ax2) = plt.subplots(1,2) + fig, (ax1, ax2) = plt.subplots(1, 2) ax1.imshow(pred) - for p,l in zip(point, point_label): - ax1.scatter(p[0],p[1],c='r' if l==1 else 'g') + for p, l in zip(point, point_label): + ax1.scatter(p[0], p[1], c="r" if l == 1 else "g") ax2.imshow(label) - for p,l in zip(point, point_label): - ax2.scatter(p[0],p[1],c='r' if l==1 else 'g') + for p, l in zip(point, point_label): + ax2.scatter(p[0], p[1], c="r" if l == 1 else "g") plt.show() plt.savefig(name) plt.close() + def get_points_from_label(labels, index=1): - """ Sample the starting point - label [1, H, W, ...] + """Sample the starting point + label [1, H, W, ...] """ plabels = labels == index - plabels = monai.transforms.utils.get_largest_connected_component_mask( - plabels - ) + plabels = monai.transforms.utils.get_largest_connected_component_mask(plabels) plabelpoints = torch.nonzero(plabels) pmean = plabelpoints.float().mean(0) pdis = ((plabelpoints - pmean) ** 2).sum(-1) @@ -86,27 +87,32 @@ def get_points_from_label(labels, index=1): point = plabelpoints[sorted_indices[0]] return point + def get_center_points(plabelpoints): pmean = plabelpoints.float().mean() pdis = ((plabelpoints - pmean) ** 2).sum(-1) _, sorted_indices = torch.sort(pdis) return plabelpoints[sorted_indices[0]] + def get_points_from_false_pred(pred, gt, num_point=1): - """ sample points from false negative and positive. - """ + """sample points from false negative and positive.""" # Define the structuring element (kernel) of size 5x5 structuring_element = np.ones((3, 3), dtype=np.uint8) # handle false positive fp_mask = torch.logical_and(torch.logical_not(gt), pred) # sample from largest connected components did not show much difference. # fp_mask = monai.transforms.utils.get_largest_connected_component_mask(fp_mask) - eroded_image = binary_erosion(fp_mask.cpu().numpy(), structure=structuring_element).astype(np.uint8) + eroded_image = binary_erosion( + fp_mask.cpu().numpy(), structure=structuring_element + ).astype(np.uint8) plabelpoints = torch.nonzero(torch.from_numpy(eroded_image)) - # handle false negative + # handle false negative fn_mask = torch.logical_and(torch.logical_not(pred), gt) # fn_mask = monai.transforms.utils.get_largest_connected_component_mask(fn_mask) - eroded_image = binary_erosion(fn_mask.cpu().numpy(), structure=structuring_element).astype(np.uint8) + eroded_image = binary_erosion( + fn_mask.cpu().numpy(), structure=structuring_element + ).astype(np.uint8) nlabelpoints = torch.nonzero(torch.from_numpy(eroded_image)) _point = [] _label = [] @@ -116,7 +122,7 @@ def get_points_from_false_pred(pred, gt, num_point=1): if len(p) > 0: p = get_center_points(p) _point.append([p[1], p[0]]) - _label.append(l) + _label.append(l) if num_point == 3: if len(nlabelpoints) > 0: ppoint = get_center_points(nlabelpoints) @@ -125,15 +131,16 @@ def get_points_from_false_pred(pred, gt, num_point=1): if len(plabelpoints) > 0: npoint = get_center_points(plabelpoints) _point.append([npoint[1], npoint[0]]) - _label.append(0) + _label.append(0) p = nlabelpoints if len(nlabelpoints) > len(plabelpoints) else plabelpoints l = 1 if len(nlabelpoints) > len(plabelpoints) else 0 if len(p) > 0: p = random.choice(p) _point.append([p[1], p[0]]) - _label.append(l) + _label.append(l) return _point, _label + def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): # Initialize distributed and scale parameters based on GPU memory if torch.cuda.device_count() > 1: @@ -148,11 +155,11 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): torch.device(f"cuda:{os.environ['LOCAL_RANK']}") if world_size > 1 else torch.device("cuda:0") - ) + ) torch.cuda.set_device(device) # use bfloat16 torch.autocast(device_type=str(device), dtype=torch.bfloat16).__enter__() - + if torch.cuda.get_device_properties(0).major >= 8: # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) torch.backends.cuda.matmul.allow_tf32 = True @@ -182,20 +189,22 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): use_center = parser.get_parsed_content("use_center", default=True) output_path = parser.get_parsed_content("output_path") dataset_name = parser.get_parsed_content("dataset_name", default=None) - # remove slices without foreground + # remove slices without foreground saliency = parser.get_parsed_content("saliency", default=False) start_file = parser.get_parsed_content("start_file", default=0) end_file = parser.get_parsed_content("end_file", default=-1) # merge tumors into organs to avoid some confusions, e.g. merge liver tumor into liver thus - # may improve liver seg results. Not showing large difference. + # may improve liver seg results. Not showing large difference. merge_tumors = parser.get_parsed_content("merge_tumors", default=False) MAX_ITER = parser.get_parsed_content("max_iter", default=1) - log_output_file = parser.get_parsed_content("log_output_file").replace(".log", f"_{start_file}_{end_file}_s{saliency}.log") - parser.update(pairs={'log_output_file': log_output_file}) + log_output_file = parser.get_parsed_content("log_output_file").replace( + ".log", f"_{start_file}_{end_file}_s{saliency}.log" + ) + parser.update(pairs={"log_output_file": log_output_file}) if not os.path.exists(output_path): os.makedirs(output_path) - + if label_set is None: label_mapping = parser.get_parsed_content( "label_mapping", default="./data/jsons/label_mappings.json" @@ -214,13 +223,13 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): "log_output_file" ) # remove rank filter - CONFIG["handlers"]["file"].pop('filters') - CONFIG["handlers"]["console"].pop('filters') + CONFIG["handlers"]["file"].pop("filters") + CONFIG["handlers"]["console"].pop("filters") logging.config.dictConfig(CONFIG) logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING) logger.debug(f"Number of GPUs: {torch.cuda.device_count()}") logger.debug(f"World_size: {world_size}") - + if five_fold: train_files, val_files = datafold_read( datalist=data_list_file_path, @@ -264,7 +273,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): if end_file == -1: end_file = len(process_files) process_files = process_files[start_file:end_file] - logger.info(f'Working on files from {start_file} to {end_file}: {process_files}') + logger.info(f"Working on files from {start_file} to {end_file}: {process_files}") for i in range(len(process_files)): if ( @@ -291,8 +300,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): shuffle=False, ) - - predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint) predictor = predictor.to(device) # need to uncomment this one if run first time and failed. comment out after first run. @@ -321,12 +328,16 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): val_filename = val_data["image"].meta["filename_or_obj"][0] _index += 1 name_parts = val_filename.split("/") - video_dir=os.path.join(output_path, dataset_name, - name_parts[-2]+ "_" + name_parts[-1].split(".")[0]) + video_dir = os.path.join( + output_path, + dataset_name, + name_parts[-2] + "_" + name_parts[-1].split(".")[0], + ) save_nifti_frames_to_jpg(val_data["image"], video_dir) # scan all the JPEG frame names in this directory frame_names = [ - p for p in os.listdir(video_dir) + p + for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) @@ -335,33 +346,39 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): predictor.reset_state(inference_state) # loop through the label_set exist_label = sorted(val_data["label"].unique().numpy().tolist()) - exist_label = list(set(exist_label).intersection(label_set)) + exist_label = list(set(exist_label).intersection(label_set)) if len(exist_label) == 1: continue for i in range(1, len(exist_label)): label_index = exist_label[i] - label = torch.squeeze((val_data["label"] == label_index).to(torch.uint8)) + label = torch.squeeze( + (val_data["label"] == label_index).to(torch.uint8) + ) if merge_tumors: # Can only be used for Task3, Task7. Disabled by default. - logger.debug('merging tumors') + logger.debug("merging tumors") if label_index == 1: - label = label + torch.squeeze((val_data["label"] == 2).to(torch.uint8)) - # remove + label = label + torch.squeeze( + (val_data["label"] == 2).to(torch.uint8) + ) + # remove if saliency: - print('removing label without foreground') + print("removing label without foreground") z_slice = label.sum(0).sum(0) > 0 - label = label[:,:,z_slice] - inference_state = predictor.init_state(video_path=video_dir, z_slice=z_slice) + label = label[:, :, z_slice] + inference_state = predictor.init_state( + video_path=video_dir, z_slice=z_slice + ) predictor.reset_state(inference_state) - for idx in range(max_iters): + for idx in range(max_iters): if idx == 0: # select initial points from the center of ROI point = get_points_from_label(label) _point = [[point[1], point[0]]] _label = [1] ann_frame_idx = point[-1] - ann_obj_id = 1 + ann_obj_id = 1 for rounds in range(4): points = np.array(_point, dtype=np.float32) labels = np.array(_label, np.int32) @@ -374,8 +391,10 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): labels=labels, ) pred = (out_mask_logits[0] > 0.0).cpu()[0] - gt = label[:,:,point[-1]] - new_point, new_label = get_points_from_false_pred(pred, gt, num_point=1) + gt = label[:, :, point[-1]] + new_point, new_label = get_points_from_false_pred( + pred, gt, num_point=1 + ) if len(new_label) == 0: break _point.extend(new_point) @@ -384,9 +403,11 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): else: ann_frame_idx = lowerest_dice_index # select points from the slice with smallest dice - new_point, new_label = get_points_from_false_pred(pred[..., ann_frame_idx], - label[..., ann_frame_idx], - num_point=3) + new_point, new_label = get_points_from_false_pred( + pred[..., ann_frame_idx], + label[..., ann_frame_idx], + num_point=3, + ) # plot(pred[..., ann_frame_idx], label[..., ann_frame_idx], new_point, new_label, f'{idx}_{rounds}.png') if len(new_point) > 0: points = np.array(new_point, dtype=np.float32) @@ -402,20 +423,35 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): print("cannot find new points! End the iteration") break - video_segments = {} # video_segments contains the per-frame segmentation results - for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, reverse=False): + video_segments = ( + {} + ) # video_segments contains the per-frame segmentation results + for ( + out_frame_idx, + out_obj_ids, + out_mask_logits, + ) in predictor.propagate_in_video(inference_state, reverse=False): video_segments[out_frame_idx] = { out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) } - for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state, reverse=True): + for ( + out_frame_idx, + out_obj_ids, + out_mask_logits, + ) in predictor.propagate_in_video(inference_state, reverse=True): video_segments[out_frame_idx] = { out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) } #### - pred = [video_segments[i][ann_obj_id][0] for i in sorted(list(video_segments.keys()))] - pred = torch.from_numpy(np.stack(pred).transpose(1,2,0)).to(torch.uint8) + pred = [ + video_segments[i][ann_obj_id][0] + for i in sorted(list(video_segments.keys())) + ] + pred = torch.from_numpy(np.stack(pred).transpose(1, 2, 0)).to( + torch.uint8 + ) # compute per-frame dice lowerest_dice = 1000 @@ -423,27 +459,31 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): for d in range(pred.shape[-1]): if torch.sum(label[..., d]) > 0: pt_frame_dice = compute_dice( - y_pred=pred[..., d].unsqueeze(0).unsqueeze(0), + y_pred=pred[..., d].unsqueeze(0).unsqueeze(0), y=label[..., d].unsqueeze(0).unsqueeze(0), - include_background=False + include_background=False, ) if pt_frame_dice < lowerest_dice: lowerest_dice_index = d lowerest_dice = pt_frame_dice - max_error_pixel = torch.abs(pred[..., d] - label[..., d]).sum() + max_error_pixel = torch.abs( + pred[..., d] - label[..., d] + ).sum() elif pt_frame_dice == lowerest_dice: - error_pixel = torch.abs(pred[..., d] - label[..., d]).sum() + error_pixel = torch.abs( + pred[..., d] - label[..., d] + ).sum() if max_error_pixel < error_pixel: lowerest_dice_index = d max_error_pixel = error_pixel # compute volume dice pt_volume_dice = compute_dice( - y_pred=pred.unsqueeze(0).unsqueeze(0), - y=label.unsqueeze(0).unsqueeze(0), - include_background=False - ) - + y_pred=pred.unsqueeze(0).unsqueeze(0), + y=label.unsqueeze(0).unsqueeze(0), + include_background=False, + ) + print(f"iter {idx}, pt_volume_dice", pt_volume_dice) metric[_index - 1, i - 1, idx] = pt_volume_dice @@ -500,7 +540,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): fire.Fire() -##### functions to plot +##### functions to plot # import torch # import numpy as np # import matplotlib.pyplot as plt @@ -515,7 +555,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): # print('total clip number', len(dataclip)) # else: # dataclip = torch.load(clip)['metric'].numpy() - + # if type(noclip) is list: # datanoclip = [] # for i in noclip: @@ -552,18 +592,18 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): # end = dataclip.shape[1] # for i in range(start, end): # axes[0].plot(x_values, mean_dataclip[i], label=f'{classes[i]}') -# axes[0].fill_between(x_values, mean_dataclip[i] - 1.96 * sem_datanoclip[i], -# mean_dataclip[i] + 1.96 * sem_datanoclip[i], +# axes[0].fill_between(x_values, mean_dataclip[i] - 1.96 * sem_datanoclip[i], +# mean_dataclip[i] + 1.96 * sem_datanoclip[i], # alpha=0.2) # axes[0].set_xlabel('Number of annotated slices') # axes[0].set_ylabel('Mean Dice Value') # axes[0].legend() # axes[0].grid(True) # axes[0].set_title(f'{dataset} results with background removal') - + # axes[1].plot(x_values, mean_datanoclip[i], label=f'{classes[i]}') -# axes[1].fill_between(x_values, mean_datanoclip[i] - 1.96 * sem_datanoclip[i], -# mean_datanoclip[i] + 1.96 * sem_datanoclip[i], +# axes[1].fill_between(x_values, mean_datanoclip[i] - 1.96 * sem_datanoclip[i], +# mean_datanoclip[i] + 1.96 * sem_datanoclip[i], # alpha=0.2) # axes[1].set_xlabel('Number of annotated slices without background removal') @@ -572,6 +612,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override): # axes[1].grid(True) # axes[1].set_title(f'{dataset} results without background removal') -# plt.show() +# plt.show() # #task07 -# plot_func('validation_auto_clipTask07.pt', 'validation_auto_Task07.pt', ['Task07 Pancreas', 'Task07 Pancreas Tumor'], 'MSD Task07') \ No newline at end of file +# plot_func('validation_auto_clipTask07.pt', 'validation_auto_Task07.pt', ['Task07 Pancreas', 'Task07 Pancreas Tumor'], 'MSD Task07')