From e3b2dbb78d6f7079ac5a261c2c88783047dc0251 Mon Sep 17 00:00:00 2001
From: Luca
Date: Tue, 9 Jul 2024 17:28:48 +0200
Subject: [PATCH] debug: train SAM without segmentation decoder on mito tomo

---
 development/predict_3d_model_with_lucchi.py   |  49 ++--
 development/train_3d_model_with_lucchi.py     |  13 +-
 ...in_3d_model_with_lucchi_without_decoder.py | 252 ++++++++++++++++++
 3 files changed, 286 insertions(+), 28 deletions(-)
 create mode 100644 development/train_3d_model_with_lucchi_without_decoder.py

diff --git a/development/predict_3d_model_with_lucchi.py b/development/predict_3d_model_with_lucchi.py
index fe4fe250..4279b378 100644
--- a/development/predict_3d_model_with_lucchi.py
+++ b/development/predict_3d_model_with_lucchi.py
@@ -96,31 +96,32 @@ def run_semantic_segmentation_3d(
         assert os.path.exists(image_path), image_path
 
         # Perform segmentation only on the semantic class
-        for i, (semantic_class_name, _) in enumerate(semantic_class_map.items()):
-            if is_multiclass:
-                semantic_class_name = "all"
-                if i > 0:  # We only perform segmentation for multiclass once.
-                    continue
+        # for i, (semantic_class_name, _) in enumerate(semantic_class_map.items()):
+        #     if is_multiclass:
+        #         semantic_class_name = "all"
+        #         if i > 0:  # We only perform segmentation for multiclass once.
+        #             continue
+        semantic_class_name = "all"  # we only perform the segmentation once for multiclass
 
         # We skip the images that already have been segmented
-            image_name = os.path.splitext(image_name)[0] + ".tif"
-            prediction_path = os.path.join(prediction_dir, semantic_class_name, image_name)
-            if os.path.exists(prediction_path):
-                continue
+        image_name = os.path.splitext(image_name)[0] + ".tif"
+        prediction_path = os.path.join(prediction_dir, "all", image_name)
+        if os.path.exists(prediction_path):
+            continue
 
-            if image_key is None:
-                image = imageio.imread(image_path)
-            else:
-                with open_file(image_path, "r") as f:
-                    image = f[image_key][:]
+        if image_key is None:
+            image = imageio.imread(image_path)
+        else:
+            with open_file(image_path, "r") as f:
+                image = f[image_key][:]
 
-            # create the prediction folder
-            os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True)
+        # create the prediction folder
+        os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True)
 
-            _run_semantic_segmentation_for_image_3d(
-                model=model, image=image, prediction_path=prediction_path,
-                patch_shape=patch_shape, halo=halo,
-            )
+        _run_semantic_segmentation_for_image_3d(
+            model=model, image=image, prediction_path=prediction_path,
+            patch_shape=patch_shape, halo=halo,
+        )
 
 
 def transform_labels(y):
@@ -144,7 +145,9 @@ def predict(args):
     checkpoint = torch.load(cp_path, map_location=device)
 
     # # Load the state dictionary from the checkpoint
-    model.load_state_dict(checkpoint['model'].state_dict())
+    for k in checkpoint.keys():
+        print("checkpoint key:", k)
+    model.load_state_dict(checkpoint['model_state'])
     model.eval()
 
     data_paths = glob(os.path.join(args.input_path, "**/*test.h5"), recursive=True)
@@ -169,7 +172,7 @@ def main():
     )
     parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)")
     parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations")
-    parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict")
+    parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict")
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--num_workers", type=int, default=4, help="num_workers") parser.add_argument( @@ -177,7 +180,7 @@ def main(): help="The filepath to where the logs and the checkpoints will be saved." ) parser.add_argument( - "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-lucchi-train/", + "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-vitb-masamhyp-lucchi", help="The filepath to where the logs and the checkpoints will be saved." ) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 4b41dd1a..9ff888c1 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -124,7 +124,6 @@ def train_on_lucchi(args): save_root = args.save_root - device = "cuda" if torch.cuda.is_available() else "cpu" if args.without_lora: sam_3d = get_sam_3d_model( @@ -135,7 +134,10 @@ def train_on_lucchi(args): device, n_classes=n_classes, image_size=patch_shape[1], model_type=model_type, lora_rank=4) train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape) - optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1) + #optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1) + optimizer = torch.optim.Adam(sam_3d.parameters(), lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=15, verbose=True) + #masam no scheduler trainer = SemanticSamTrainer( @@ -146,6 +148,7 @@ def train_on_lucchi(args): train_loader=train_loader, val_loader=val_loader, optimizer=optimizer, + lr_scheduler=scheduler, device=device, compile_model=False, save_root=save_root, @@ -170,15 +173,15 @@ def main(): parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs") parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict") - parser.add_argument("--batch_size", "-bs", type=int, default=3, help="Batch size") + parser.add_argument("--batch_size", "-bs", type=int, default=1, help="Batch size") # masam 3 parser.add_argument("--num_workers", type=int, default=4, help="num_workers") - parser.add_argument("--learning_rate", type=float, default=0.0008, help="base learning rate") + parser.add_argument("--learning_rate", type=float, default=1e-5, help="base learning rate") # MASAM 0.0008 parser.add_argument( "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d", help="The filepath to where the logs and the checkpoints will be saved." ) parser.add_argument( - "--exp_name", default="vitb_3d_lora4", + "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi", help="The filepath to where the logs and the checkpoints will be saved." 
     )
diff --git a/development/train_3d_model_with_lucchi_without_decoder.py b/development/train_3d_model_with_lucchi_without_decoder.py
new file mode 100644
index 00000000..664fc6db
--- /dev/null
+++ b/development/train_3d_model_with_lucchi_without_decoder.py
@@ -0,0 +1,252 @@
+import argparse
+import os
+from glob import glob
+
+import h5py
+import numpy as np
+import torch
+import torch_em
+from micro_sam.training import train_sam, default_sam_dataset
+from skimage.measure import regionprops
+from torch_em.data.sampler import MinInstanceSampler
+from torch_em.segmentation import get_data_loader
+
+
+def get_rois_coordinates_skimage(file, label_key, min_shape, euler_threshold=None, min_amount_pixels=None):
+    """Compute a fixed-size ROI (a tuple of slices) around each labeled region in a 3D label image,
+    using skimage.measure.regionprops.
+
+    Args:
+        file (h5py.File): Handle to the open HDF5 file.
+        label_key (str): Key for the label data within the HDF5 file.
+        min_shape (tuple): Minimum size of the ROI along each dimension.
+        euler_threshold (int, optional): If given, only regions with this Euler number are considered.
+        min_amount_pixels (int, optional): If given, only regions with at least this many pixels are considered.
+
+    Returns:
+        dict or None: A dictionary mapping each label id to its ROI (a tuple of slices),
+            or None if no regions are found.
+    """
+
+    # regionprops expects an in-memory integer label image, so load the dataset from HDF5.
+    label_data = file[label_key][:]
+    label_shape = label_data.shape
+
+    # Find the connected regions (objects) with regionprops.
+    regions = regionprops(label_data)
+    if not regions:
+        return None
+
+    label_extents = {}
+    for region in regions:
+        if euler_threshold is not None and region.euler_number != euler_threshold:
+            continue
+        if min_amount_pixels is not None and region.area < min_amount_pixels:
+            continue
+
+        label = region.label
+        min_coords = region.bbox[:3]  # the start coordinates of the region's bounding box
+
+        # Clip the start coordinates per dimension so that a ROI of min_shape fits into the volume.
+        clipped_min_coords = np.clip(min_coords, 0, np.array(label_shape) - np.array(min_shape))
+        roi_extent = tuple(
+            slice(int(min_val), int(min_val) + size) for min_val, size in zip(clipped_min_coords, min_shape)
+        )
+
+        # Skip ROIs that contain (almost) no label pixels.
+        roi_data = file[label_key][roi_extent]
+        amount_label_pixels = np.count_nonzero(roi_data)
+        if amount_label_pixels < 100 or (min_amount_pixels is not None and amount_label_pixels < min_amount_pixels):
+            continue
+
+        label_extents[label] = roi_extent
+
+    return label_extents
+
+
+def get_data_paths_and_rois(data_dir, min_shape,
+                            data_format="*.h5",
+                            image_key="raw",
+                            label_key_mito="labels/mitochondria",
+                            label_key_cristae="labels/cristae",
+                            with_thresholds=True):
+
+    data_paths = glob(os.path.join(data_dir, "**", data_format), recursive=True)
+    rois_list = []
+    new_data_paths = []  # one data path for each ROI
+
+    for data_path in data_paths:
+        try:
+            # Open the HDF5 file in read-only mode
h5py.File(data_path, "r") as f: + # Check for existence of image and label datasets (considering key flexibility) + if image_key not in f or (label_key_mito is not None and label_key_mito not in f): + print(f"Warning: Key(s) missing in {data_path}. Skipping {image_key}") + continue + + #label_data_mito = f[label_key_mito][()] if label_key_mito is not None else None + + # Extract ROIs (assuming ndim of label data is the same as image data) + if with_thresholds: + rois = get_rois_coordinates_skimage(f, label_key_mito, min_shape, min_amount_pixels=100) # euler_threshold=1, + else: + rois = get_rois_coordinates_skimage(f, label_key_mito, min_shape, euler_threshold=None, min_amount_pixels=None) + for label_id, roi in rois.items(): + rois_list.append(roi) + new_data_paths.append(data_path) + except OSError: + print(f"Error accessing file: {data_path}. Skipping...") + + return new_data_paths, rois_list + + +def split_data_paths_to_dict(data_paths, rois_list, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1): + """ + Splits data paths and ROIs into training, validation, and testing sets without shuffling. + + Args: + data_paths (list): List of paths to all HDF5 files. + rois_list (list): List of ROIs corresponding to the data paths. + train_ratio (float, optional): Proportion of data for training (0.0-1.0) (default: 0.8). + val_ratio (float, optional): Proportion of data for validation (0.0-1.0) (default: 0.1). + test_ratio (float, optional): Proportion of data for testing (0.0-1.0) (default: 0.1). + + Returns: + tuple: A tuple containing two dictionaries: + - data_split: Dictionary containing "train", "val", and "test" keys with data paths. + - rois_split: Dictionary containing "train", "val", and "test" keys with corresponding ROIs. + """ + + if train_ratio + val_ratio + test_ratio != 1.0: + raise ValueError("Sum of train, validation, and test ratios must equal 1.0.") + num_data = len(data_paths) + if rois_list is not None: + if len(rois_list) != num_data: + raise ValueError(f"Length of data paths and number of ROIs in the dictionary must match: len rois {len(rois_list)}, len data_paths {len(data_paths)}") + + train_size = int(num_data * train_ratio) + val_size = int(num_data * val_ratio) # Optional validation set + test_size = num_data - train_size - val_size + + data_split = { + "train": data_paths[:train_size], + "val": data_paths[train_size:train_size+val_size], + "test": data_paths[train_size+val_size:] + } + + if rois_list is not None: + rois_split = { + "train": rois_list[:train_size], + "val": rois_list[train_size:train_size+val_size], + "test": rois_list[train_size+val_size:] + } + + return data_split, rois_split + else: + return data_split + + +def get_data_paths(data_dir, data_format="*.h5"): + data_paths = glob(os.path.join(data_dir, "**", data_format), recursive=True) + return data_paths + + +def train(args): + n_workers = 4 if torch.cuda.is_available() else 1 + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = args.input_path + with_rois = True if args.without_rois is False else False + patch_shape = args.patch_shape + label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) + ndim = 3 + + if with_rois: + data_paths, rois_dict = get_data_paths_and_rois(data_dir, min_shape=patch_shape, with_thresholds=True) + data, rois_dict = split_data_paths_to_dict(data_paths, rois_dict, train_ratio=.8, val_ratio=0.2, test_ratio=0) + else: + data_paths = get_data_paths(data_dir) + data = split_data_paths_to_dict(data_paths, rois_list=None, 
+                                        train_ratio=.5, val_ratio=0.5, test_ratio=0)
+    # path = "/scratch-emmy/projects/nim00007/fruit-fly-data/cambridge_data/parker_s2_soma_roi_z472-600_y795-1372_x1122-1687_clahed.zarr"
+    label_key = "labels/mitochondria"  # "./annotations1.tif"
+
+    # train_ds = default_sam_dataset(
+    #     raw_paths=data["train"][0], raw_key="raw",
+    #     label_paths=data["train"][0], label_key=label_key,
+    #     patch_shape=args.patch_shape, with_segmentation_decoder=False,
+    #     sampler=MinInstanceSampler(3),
+    #     # rois=rois_dict["train"],
+    #     n_samples=200,
+    # )
+    # train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)
+
+    # val_ds = default_sam_dataset(
+    #     raw_paths=data["val"][0], raw_key="raw",
+    #     label_paths=data["val"][0], label_key=label_key,
+    #     patch_shape=args.patch_shape, with_segmentation_decoder=False,
+    #     sampler=MinInstanceSampler(3),
+    #     # rois=rois_dict["val"],
+    #     is_train=False, n_samples=25,
+    # )
+    # val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1)
+
+    train_loader = torch_em.default_segmentation_loader(
+        raw_paths=data["train"], raw_key="raw",
+        label_paths=data["train"], label_key=label_key,
+        patch_shape=patch_shape, ndim=ndim, batch_size=1,
+        label_transform=label_transform, num_workers=n_workers,
+    )
+    val_loader = torch_em.default_segmentation_loader(
+        raw_paths=data["val"], raw_key="raw",
+        label_paths=data["val"], label_key=label_key,
+        patch_shape=patch_shape, ndim=ndim, batch_size=1,
+        label_transform=label_transform, num_workers=n_workers,
+    )
+
+    train_sam(
+        name="mito_model", model_type=args.model_type,
+        train_loader=train_loader, val_loader=val_loader,
+        n_epochs=50, n_objects_per_batch=10,
+        with_segmentation_decoder=False,
+        save_root=args.save_root,
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Finetune Segment Anything on mitochondria tomograms (mito tomo).")
+    parser.add_argument(
+        "--input_path", "-i", default="/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/",
+        help="The filepath to the mitochondria tomogram data."
+    )
+    parser.add_argument(
+        "--model_type", "-m", default="vit_b",
+        help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
+    )
+    parser.add_argument("--without_lora", action="store_true", help="Train without LoRA, i.e. fully finetune SAM for semantic segmentation.")
+    parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)")
+
+    parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs")
+    parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict")
+    parser.add_argument("--batch_size", "-bs", type=int, default=1, help="Batch size")  # MASAM uses 3
+    parser.add_argument("--num_workers", type=int, default=4, help="num_workers")
+    parser.add_argument("--learning_rate", type=float, default=1e-5, help="base learning rate")  # MASAM uses 0.0008
+    parser.add_argument(
+        "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam_training_on_mitotomo",
+        help="The filepath to where the logs and the checkpoints will be saved."
+    )
+    parser.add_argument(
+        "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi",
+        help="Name of the experiment."
+    )
+    # argparse's type=bool would treat any non-empty string as True, so we expose
+    # an explicit flag instead; by default training runs without ROIs.
+    parser.add_argument(
+        "--with_rois", dest="without_rois", action="store_false",
+        help="Train with regions of interest (ROIs) around the labeled objects."
+    )
+
+    args = parser.parse_args()
+    train(args)
+
+
+if __name__ == "__main__":
+    main()