From eb83d8cce6e3a35291b154d045ddc7e2f06d8c78 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:55:36 -0700 Subject: [PATCH] Config-based prediction with Xarray-based output format (#132) * use callback to write prediction embeddings * moving over the script to compute infection score from contrastive_update * delete unused stem module * organize scripts and CLIs for contrastive phenotyping * add dependencies for prediction * export embedding dataset reader function * add more plots to script * use real paths in predict config * do not require seaborn and umap-learn for base install * use relative path in example job script * add docstrings for embedding writer and reader * don't assign unused grid object * show time and id as hover data in interactive plot * fix typo * fix script to test data i/o * ignore accidental lightning_logs * add plotly and nbformat to visual dependencies * tweak predict cli example * add another plot type - raw features of random samples * comment on speed of clustermap * add prediction config example to specify log path * simplify env var in job script and match cpu count with config * vectorize string concatenation --------- Co-authored-by: Shalin Mehta --- .gitignore | 5 +- .../fit.yml} | 0 .../fit_slurm.sh} | 0 .../contrastive_cli/plot_embeddings.py | 145 +++++++++++++++ .../contrastive_cli/predict.yml | 50 ++++++ .../contrastive_cli/predict_slurm.sh | 21 +++ .../{ => contrastive_scripts}/demo_fit.py | 0 .../graphs_ConvNeXt_ResNet.py | 0 .../{ => contrastive_scripts}/predict.py | 0 .../predict_infection_score_supervised.py | 166 ++++++++++++++++++ .../contrastive_scripts/profile_dataloader.py | 119 +++++++++++++ .../profile_dataloader.sh | 0 .../training_script.py | 0 .../profile_dataloader.py | 113 ------------ pyproject.toml | 4 +- viscy/data/triplet.py | 8 +- viscy/light/embedding_writer.py | 71 ++++++++ viscy/light/engine.py | 85 +-------- viscy/unet/networks/resnet.py | 30 ---- 19 files changed, 590 insertions(+), 227 deletions(-) rename applications/contrastive_phenotyping/{demo_cli_fit.yml => contrastive_cli/fit.yml} (100%) rename applications/contrastive_phenotyping/{demo_cli_fit_slurm.sh => contrastive_cli/fit_slurm.sh} (100%) create mode 100644 applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py create mode 100644 applications/contrastive_phenotyping/contrastive_cli/predict.yml create mode 100644 applications/contrastive_phenotyping/contrastive_cli/predict_slurm.sh rename applications/contrastive_phenotyping/{ => contrastive_scripts}/demo_fit.py (100%) rename applications/contrastive_phenotyping/{ => contrastive_scripts}/graphs_ConvNeXt_ResNet.py (100%) rename applications/contrastive_phenotyping/{ => contrastive_scripts}/predict.py (100%) create mode 100644 applications/contrastive_phenotyping/contrastive_scripts/predict_infection_score_supervised.py create mode 100644 applications/contrastive_phenotyping/contrastive_scripts/profile_dataloader.py rename applications/contrastive_phenotyping/{ => contrastive_scripts}/profile_dataloader.sh (100%) rename applications/contrastive_phenotyping/{ => contrastive_scripts}/training_script.py (100%) delete mode 100644 applications/contrastive_phenotyping/profile_dataloader.py create mode 100644 viscy/light/embedding_writer.py delete mode 100644 viscy/unet/networks/resnet.py diff --git a/.gitignore b/.gitignore index d84853e4..c8460759 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,7 @@ htmlcov/ coverage.xml *.cover .hypothesis/ -.pytest_cache/ \ No newline at end of file +.pytest_cache/ + +#lightning_logs directory +lightning_logs/ \ No newline at end of file diff --git a/applications/contrastive_phenotyping/demo_cli_fit.yml b/applications/contrastive_phenotyping/contrastive_cli/fit.yml similarity index 100% rename from applications/contrastive_phenotyping/demo_cli_fit.yml rename to applications/contrastive_phenotyping/contrastive_cli/fit.yml diff --git a/applications/contrastive_phenotyping/demo_cli_fit_slurm.sh b/applications/contrastive_phenotyping/contrastive_cli/fit_slurm.sh similarity index 100% rename from applications/contrastive_phenotyping/demo_cli_fit_slurm.sh rename to applications/contrastive_phenotyping/contrastive_cli/fit_slurm.sh diff --git a/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py b/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py new file mode 100644 index 00000000..1721c659 --- /dev/null +++ b/applications/contrastive_phenotyping/contrastive_cli/plot_embeddings.py @@ -0,0 +1,145 @@ +# %% +from pathlib import Path + +import numpy as np +import pandas as pd +import plotly.express as px +import seaborn as sns +from sklearn.preprocessing import StandardScaler +from umap import UMAP + +from viscy.light.embedding_writer import read_embedding_dataset + +# %% +dataset = read_embedding_dataset( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04-tokenized-drop_path_0_0.zarr" +) +dataset + +# %% +# load all unprojected features: +features = dataset["features"] +# or select a well: +# features = features[features["fov_name"].str.contains("B/4")] +features + +# %% +# examine raw features +random_samples = np.random.randint(0, dataset.sizes["sample"], 700) +# concatenate fov_name, track_id, and t to create a unique sample identifier +sample_id = ( + features["fov_name"][random_samples] + + "-" + + features["track_id"][random_samples].astype(str) + + "-" + + features["t"][random_samples].astype(str) +) +px.imshow( + features.values[random_samples], + labels={ + "x": "feature", + "y": "sample", + "color": "value", + }, # change labels to match our metadata + y=sample_id, + # show fov_name as y-axis +) + +# %% +scaled_features = StandardScaler().fit_transform(features.values) + +umap = UMAP() + +embedding = umap.fit_transform(scaled_features) +features = ( + features.assign_coords(UMAP1=("sample", embedding[:, 0])) + .assign_coords(UMAP2=("sample", embedding[:, 1])) + .set_index(sample=["UMAP1", "UMAP2"], append=True) +) +features + +# %% +sns.scatterplot( + x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8 +) + + +# %% +def load_annotation(da, path, name, categories: dict | None = None): + annotation = pd.read_csv(path) + annotation["fov_name"] = "/" + annotation["fov ID"] + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays( + [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + ) + selected = annotation.loc[mi][name] + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + + +# %% +ann_root = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track" +) + +infection = load_annotation( + features, + ann_root / "tracking_v1_infection.csv", + "infection class", + {0.0: "background", 1.0: "uninfected", 2.0: "infected"}, +) +division = load_annotation( + features, + ann_root / "cell_division_state.csv", + "division", + {0: "non-dividing", 2: "dividing"}, +) + + +# %% +sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=division, s=7, alpha=0.8) + +# %% +sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8) + +# %% +ax = sns.histplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, bins=64) +sns.move_legend(ax, loc="lower left") + +# %% +sns.displot( + x=features["UMAP1"], + y=features["UMAP2"], + kind="hist", + col=infection, + bins=64, + cmap="inferno", +) + +# %% +# interactive scatter plot to associate clusters with specific cells + +px.scatter( + data_frame=pd.DataFrame( + {k: v for k, v in features.coords.items() if k != "features"} + ), + x="UMAP1", + y="UMAP2", + color=(infection.astype(str) + " " + division.astype(str)).rename("annotation"), + hover_name="fov_name", + hover_data=["id", "t"], +) + +# %% +# cluster features in heatmap directly +# this is very slow for large datasets even with fastcluster installed +inf_codes = pd.Series(infection.values.codes, name="infection") +lut = dict(zip(inf_codes.unique(), "brw")) +row_colors = inf_codes.map(lut) + +g = sns.clustermap( + scaled_features, row_colors=row_colors.to_numpy(), col_cluster=False, cbar_pos=None +) +g.yaxis.set_ticks([]) +# %% diff --git a/applications/contrastive_phenotyping/contrastive_cli/predict.yml b/applications/contrastive_phenotyping/contrastive_cli/predict.yml new file mode 100644 index 00000000..038cbbc7 --- /dev/null +++ b/applications/contrastive_phenotyping/contrastive_cli/predict.yml @@ -0,0 +1,50 @@ +seed_everything: 42 +trainer: + accelerator: gpu + strategy: auto + devices: auto + num_nodes: 1 + precision: 32-true + callbacks: + - class_path: viscy.light.embedding_writer.EmbeddingWriter + init_args: + output_path: "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/test_prediction_code.zarr" + # edit the following lines to specify logging path + # - class_path: lightning.pytorch.loggers.TensorBoardLogger + # init_args: + # save_dir: /path/to/save_dir + # version: name-of-experiment + # log_graph: True + inference_mode: true +model: + backbone: convnext_tiny + in_channels: 2 + in_stack_depth: 15 + stem_kernel_size: [5, 4, 4] +data: + data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr + tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr + source_channel: + - Phase3D + - RFP + z_range: [28, 43] + batch_size: 32 + num_workers: 16 + initial_yx_patch_size: [192, 192] + final_yx_patch_size: [192, 192] + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [Phase3D] + level: fov_statistics + subtrahend: mean + divisor: std + - class_path: viscy.transforms.ScaleIntensityRangePercentilesd + init_args: + keys: [RFP] + lower: 50 + upper: 99 + b_min: 0.0 + b_max: 1.0 +return_predictions: false +ckpt_path: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs/tokenized-drop-path-0.0/checkpoints/epoch=96-step=23377.ckpt diff --git a/applications/contrastive_phenotyping/contrastive_cli/predict_slurm.sh b/applications/contrastive_phenotyping/contrastive_cli/predict_slurm.sh new file mode 100644 index 00000000..3f91fc9b --- /dev/null +++ b/applications/contrastive_phenotyping/contrastive_cli/predict_slurm.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +#SBATCH --job-name=contrastive_predict +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=16 +#SBATCH --mem-per-cpu=7G +#SBATCH --time=0-01:00:00 + +module load anaconda/2022.05 +# Update to use the actual prefix +conda activate $MYDATA/envs/viscy + +scontrol show job $SLURM_JOB_ID + +# use absolute path in production +config=./predict.yml +cat $config +srun python -m viscy.cli.contrastive_triplet predict -c $config diff --git a/applications/contrastive_phenotyping/demo_fit.py b/applications/contrastive_phenotyping/contrastive_scripts/demo_fit.py similarity index 100% rename from applications/contrastive_phenotyping/demo_fit.py rename to applications/contrastive_phenotyping/contrastive_scripts/demo_fit.py diff --git a/applications/contrastive_phenotyping/graphs_ConvNeXt_ResNet.py b/applications/contrastive_phenotyping/contrastive_scripts/graphs_ConvNeXt_ResNet.py similarity index 100% rename from applications/contrastive_phenotyping/graphs_ConvNeXt_ResNet.py rename to applications/contrastive_phenotyping/contrastive_scripts/graphs_ConvNeXt_ResNet.py diff --git a/applications/contrastive_phenotyping/predict.py b/applications/contrastive_phenotyping/contrastive_scripts/predict.py similarity index 100% rename from applications/contrastive_phenotyping/predict.py rename to applications/contrastive_phenotyping/contrastive_scripts/predict.py diff --git a/applications/contrastive_phenotyping/contrastive_scripts/predict_infection_score_supervised.py b/applications/contrastive_phenotyping/contrastive_scripts/predict_infection_score_supervised.py new file mode 100644 index 00000000..f20901b9 --- /dev/null +++ b/applications/contrastive_phenotyping/contrastive_scripts/predict_infection_score_supervised.py @@ -0,0 +1,166 @@ +from argparse import ArgumentParser +from pathlib import Path +import numpy as np +import os +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from viscy.data.triplet import TripletDataModule, TripletDataset +import pandas as pd +import warnings + +warnings.filterwarnings( + "ignore", + category=UserWarning, + message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", +) + +# %% Paths and constants +save_dir = ( + "/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4" +) + +# rechunked data +data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.2-register_annotations/updated_all_annotations.zarr" + +# updated tracking data +tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" + +source_channel = ["background_mask", "uninfected_mask", "infected_mask"] +z_range = (0, 1) +batch_size = 1 # match the number of fovs being processed such that no data is left +# set to 15 for full, 12 for infected, and 8 for uninfected + +# non-rechunked data +data_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" + +# updated tracking data +tracks_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" + +source_channel_1 = ["Nuclei_prediction_labels"] + + +# %% Define the main function for training +def main(hparams): + # Initialize the data module for prediction, re-do embeddings but with size 224 by 224 + data_module = TripletDataModule( + data_path=data_path, + tracks_path=tracks_path, + source_channel=source_channel, + z_range=z_range, + initial_yx_patch_size=(224, 224), + final_yx_patch_size=(224, 224), + batch_size=batch_size, + num_workers=hparams.num_workers, + ) + + data_module.setup(stage="predict") + + print(f"Total prediction dataset size: {len(data_module.predict_dataset)}") + + dataloader = DataLoader( + data_module.predict_dataset, + batch_size=batch_size, + num_workers=hparams.num_workers, + ) + + # Initialize the second data module for segmentation masks + seg_data_module = TripletDataModule( + data_path=data_path_1, + tracks_path=tracks_path_1, + source_channel=source_channel_1, + z_range=z_range, + initial_yx_patch_size=(224, 224), + final_yx_patch_size=(224, 224), + batch_size=batch_size, + num_workers=hparams.num_workers, + ) + + seg_data_module.setup(stage="predict") + + seg_dataloader = DataLoader( + seg_data_module.predict_dataset, + batch_size=batch_size, + num_workers=hparams.num_workers, + ) + + # Initialize lists to store average values + background_avg = [] + uninfected_avg = [] + infected_avg = [] + + for batch, seg_batch in tqdm( + zip(dataloader, seg_dataloader), + desc="Processing batches", + total=len(data_module.predict_dataset), + ): + anchor = batch["anchor"] + seg_anchor = seg_batch["anchor"].int() + + # Extract the fov_name and id from the batch + fov_name = batch["index"]["fov_name"][0] + cell_id = batch["index"]["id"].item() + + fov_dirs = fov_name.split("/") + # Construct the path to the CSV file + csv_path = os.path.join( + tracks_path, *fov_dirs, f"tracks{fov_name.replace('/', '_')}.csv" + ) + + # Read the CSV file + df = pd.read_csv(csv_path) + + # Find the row with the specified id and extract the track_id + track_id = df.loc[df["id"] == cell_id, "track_id"].values[0] + + # Create a boolean mask where segmentation values are equal to the track_id + mask = seg_anchor == track_id + # mask = (seg_anchor > 0) + + # Find the most frequent non-zero value in seg_anchor + # unique, counts = np.unique(seg_anchor[seg_anchor > 0], return_counts=True) + # most_frequent_value = unique[np.argmax(counts)] + + # # Create a boolean mask where segmentation values are equal to the most frequent value + # mask = (seg_anchor == most_frequent_value) + + # Expand the mask to match the anchor tensor shape + mask = mask.expand(1, 3, 1, 224, 224) + + # Calculate average values for each channel (background, uninfected, infected) using the mask + background_avg.append(anchor[:, 0, :, :, :][mask[:, 0]].mean().item()) + uninfected_avg.append(anchor[:, 1, :, :, :][mask[:, 1]].mean().item()) + infected_avg.append(anchor[:, 2, :, :, :][mask[:, 2]].mean().item()) + + # Convert lists to numpy arrays + background_avg = np.array(background_avg) + uninfected_avg = np.array(uninfected_avg) + infected_avg = np.array(infected_avg) + + print("Average values per cell for each mask calculated.") + print("Background average shape:", background_avg.shape) + print("Uninfected average shape:", uninfected_avg.shape) + print("Infected average shape:", infected_avg.shape) + + # Save the averages as .npy files + np.save(os.path.join(save_dir, "background_avg.npy"), background_avg) + np.save(os.path.join(save_dir, "uninfected_avg.npy"), uninfected_avg) + np.save(os.path.join(save_dir, "infected_avg.npy"), infected_avg) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--backbone", type=str, default="resnet50") + parser.add_argument("--margin", type=float, default=0.5) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--schedule", type=str, default="Constant") + parser.add_argument("--log_steps_per_epoch", type=int, default=10) + parser.add_argument("--embedding_len", type=int, default=256) + parser.add_argument("--max_epochs", type=int, default=100) + parser.add_argument("--accelerator", type=str, default="gpu") + parser.add_argument("--devices", type=int, default=1) + parser.add_argument("--num_nodes", type=int, default=1) + parser.add_argument("--log_every_n_steps", type=int, default=1) + parser.add_argument("--num_workers", type=int, default=8) + args = parser.parse_args() + main(args) diff --git a/applications/contrastive_phenotyping/contrastive_scripts/profile_dataloader.py b/applications/contrastive_phenotyping/contrastive_scripts/profile_dataloader.py new file mode 100644 index 00000000..57fe0a03 --- /dev/null +++ b/applications/contrastive_phenotyping/contrastive_scripts/profile_dataloader.py @@ -0,0 +1,119 @@ +# %% Imports and initialization. +import os +import time +import warnings +from pathlib import Path +from tqdm import tqdm + +from viscy.data.triplet import TripletDataModule +from monai.transforms import NormalizeIntensityd, ScaleIntensityRangePercentilesd + + +# %% Setup parameters for dataloader +# rechunked data +data_path = "/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr" + +# updated tracking data +tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr" +source_channel = ["RFP", "Phase3D"] +z_range = (28, 43) +batch_size = 32 +num_workers = 16 + +# Updated normalizations +normalizations = [ + NormalizeIntensityd( + keys=["Phase3D"], + subtrahend=None, + divisor=None, + nonzero=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False, + ), + ScaleIntensityRangePercentilesd( + keys=["RFP"], + lower=50, + upper=99, + b_min=0.0, + b_max=1.0, + clip=False, + relative=False, + channel_wise=False, + dtype=None, + allow_missing_keys=False, + ), +] + + +# %% Initialize the data module +data_module = TripletDataModule( + data_path=data_path, + tracks_path=tracks_path, + source_channel=source_channel, + z_range=z_range, + initial_yx_patch_size=(512, 512), + final_yx_patch_size=(224, 224), + batch_size=batch_size, + num_workers=num_workers, + normalizations=normalizations, +) +# for train and val +data_module.setup("fit") + +print( + f"Total dataset size: {len(data_module.train_dataset) + len(data_module.val_dataset)}" +) +print(f"Training dataset size: {len(data_module.train_dataset)}") +print(f"Validation dataset size: {len(data_module.val_dataset)}") + +# %% Profile the data i/o +num_epochs = 1 +start_time = time.time() +total_bytes_transferred = 0 # Track the total number of bytes transferred + +# Profile the data i/o +for i in range(num_epochs): + # Train dataloader + train_dataloader = data_module.train_dataloader() + train_dataloader = tqdm(train_dataloader, desc=f"Epoch {i+1}/{num_epochs} - Train") + for batch in train_dataloader: + anchor_batch = batch["anchor"] + positive_batch = batch["positive"] + negative_batch = batch["negative"] + total_bytes_transferred += ( + anchor_batch.nbytes + positive_batch.nbytes + negative_batch.nbytes + ) + # print("Anchor batch shape:", anchor_batch.shape) + # print("Positive batch shape:", positive_batch.shape) + # print("Negative batch shape:", negative_batch.shape) + + # Validation dataloader + val_dataloader = data_module.val_dataloader() + val_dataloader = tqdm(val_dataloader, desc=f"Epoch {i+1}/{num_epochs} - Validation") + for batch in val_dataloader: + anchor_batch = batch["anchor"] + positive_batch = batch["positive"] + negative_batch = batch["negative"] + total_bytes_transferred += ( + anchor_batch.nbytes + positive_batch.nbytes + negative_batch.nbytes + ) + # print("Anchor batch shape:", anchor_batch.shape) + # print("Positive batch shape:", positive_batch.shape) + # print("Negative batch shape:", negative_batch.shape) + +end_time = time.time() +elapsed_time = end_time - start_time +data_transfer_speed = (total_bytes_transferred / elapsed_time) / ( + 1024 * 1024 +) # Calculate data transfer speed in MBPS + +print("Anchor batch shape:", anchor_batch.shape) +print("Positive batch shape:", positive_batch.shape) +print("Negative batch shape:", negative_batch.shape) + +print(f"Elapsed time for {num_epochs} iterations: {elapsed_time} seconds") +print(f"Average time per iteration: {elapsed_time/num_epochs} seconds") +print(f"Data transfer speed: {data_transfer_speed} MBPS") + +# %% diff --git a/applications/contrastive_phenotyping/profile_dataloader.sh b/applications/contrastive_phenotyping/contrastive_scripts/profile_dataloader.sh similarity index 100% rename from applications/contrastive_phenotyping/profile_dataloader.sh rename to applications/contrastive_phenotyping/contrastive_scripts/profile_dataloader.sh diff --git a/applications/contrastive_phenotyping/training_script.py b/applications/contrastive_phenotyping/contrastive_scripts/training_script.py similarity index 100% rename from applications/contrastive_phenotyping/training_script.py rename to applications/contrastive_phenotyping/contrastive_scripts/training_script.py diff --git a/applications/contrastive_phenotyping/profile_dataloader.py b/applications/contrastive_phenotyping/profile_dataloader.py deleted file mode 100644 index cad32370..00000000 --- a/applications/contrastive_phenotyping/profile_dataloader.py +++ /dev/null @@ -1,113 +0,0 @@ -# %% Imports and initialization. -import os -import time -import warnings -from pathlib import Path - -import wandb -from tqdm import tqdm - -from viscy.data.hcs import ContrastiveDataModule - -warnings.filterwarnings("ignore") -os.environ["WANDB_DIR"] = f"/hpc/mydata/{os.environ['USER']}/" -data_on_lustre = Path("/hpc/projects/intracellular_dashboard/viral-sensor/") -data_on_vast = Path("/hpc/projects/virtual_staining/viral_sensor_test_dataio/") -wandb.init(project="contrastive_model", entity="alishba_imran-CZ Biohub") - -# %% Method that iterates over two epochs and logs the resource usage. - - -def profile_dataio(top_dir, num_epochs=1): - - channels = 2 - x = 200 - y = 200 - z_range = (0, 10) - batch_size = 16 - base_path = ( - top_dir / "2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/full_patch.zarr" - ) - timesteps_csv_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/final_track_timesteps.csv" - - data_module = ContrastiveDataModule( - base_path=base_path, - channels=channels, - x=x, - y=y, - timesteps_csv_path=timesteps_csv_path, - batch_size=batch_size, - num_workers=8, - z_range=z_range, - ) - - # for train and val - data_module.setup() - - print( - f"Total dataset size: {len(data_module.train_dataset) + len(data_module.val_dataset) + len(data_module.test_dataset)}" - ) - print(f"Training dataset size: {len(data_module.train_dataset)}") - print(f"Validation dataset size: {len(data_module.val_dataset)}") - print(f"Test dataset size: {len(data_module.test_dataset)}") - - start_time = time.time() - total_bytes_transferred = 0 # Track the total number of bytes transferred - - # Profile the data i/o - for i in range(num_epochs): - # Train dataloader - train_dataloader = data_module.train_dataloader() - train_dataloader = tqdm( - train_dataloader, desc=f"Epoch {i+1}/{num_epochs} - Train" - ) - for batch in train_dataloader: - anchor_batch, positive_batch, negative_batch = batch - total_bytes_transferred += ( - anchor_batch.nbytes + positive_batch.nbytes + negative_batch.nbytes - ) - # print("Anchor batch shape:", anchor_batch.shape) - # print("Positive batch shape:", positive_batch.shape) - # print("Negative batch shape:", negative_batch.shape) - - # Validation dataloader - val_dataloader = data_module.val_dataloader() - val_dataloader = tqdm( - val_dataloader, desc=f"Epoch {i+1}/{num_epochs} - Validation" - ) - for batch in val_dataloader: - anchor_batch, positive_batch, negative_batch = batch - total_bytes_transferred += ( - anchor_batch.nbytes + positive_batch.nbytes + negative_batch.nbytes - ) - # print("Anchor batch shape:", anchor_batch.shape) - # print("Positive batch shape:", positive_batch.shape) - # print("Negative batch shape:", negative_batch.shape) - - end_time = time.time() - elapsed_time = end_time - start_time - data_transfer_speed = (total_bytes_transferred / elapsed_time) / ( - 1024 * 1024 - ) # Calculate data transfer speed in MBPS - - print("Anchor batch shape:", anchor_batch.shape) - print("Positive batch shape:", positive_batch.shape) - print("Negative batch shape:", negative_batch.shape) - - print(f"Elapsed time for {num_epochs} iterations: {elapsed_time} seconds") - print(f"Average time per iteration: {elapsed_time/num_epochs} seconds") - print(f"Data transfer speed: {data_transfer_speed} MBPS") - - -# %% Testing the data i/o with data stored on Vast -print(f"Profiling data i/o with data stored on VAST\n{data_on_vast}\n") -profile_dataio(data_on_vast) - - -# %% Testing the data i/o with data stored on Lustre -print(f"Profiling data i/o with data stored on Lustre\n{data_on_lustre}\n") - -profile_dataio(data_on_lustre) - -# %% -wandb.finish() diff --git a/pyproject.toml b/pyproject.toml index fdc3c722..57248fa4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "scikit-image", "matplotlib>=3.9.0", "numpy<2", + "xarray", ] dynamic = ["version"] @@ -29,8 +30,9 @@ metrics = [ "scikit-learn>=1.1.3", "torchmetrics[detection]>=1.3.1", "ptflops>=0.7", + "umap-learn", ] -visual = ["ipykernel", "graphviz", "torchview"] +visual = ["ipykernel", "graphviz", "torchview", "seaborn", "plotly", "nbformat"] dev = [ "pytest", "pytest-cov", diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 45aab1fe..4e056851 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -14,6 +14,8 @@ _logger = logging.getLogger("lightning.pytorch") +INDEX_COLUMNS = ["fov_name", "track_id", "t", "id", "parent_track_id", "parent_id"] + def _scatter_channels( channel_names: list[str], patch: Tensor, norm_meta: NormMeta | None @@ -185,11 +187,7 @@ def __getitem__(self, index: int) -> TripletSample: patch=anchor_patch, norm_meta=anchor_norm, ) - - sample = { - "anchor": anchor_patch, - "index": anchor_row[["fov_name", "id"]].to_dict(), - } + sample = {"anchor": anchor_patch, "index": anchor_row[INDEX_COLUMNS].to_dict()} if self.fit: sample.update( { diff --git a/viscy/light/embedding_writer.py b/viscy/light/embedding_writer.py new file mode 100644 index 00000000..badd26ce --- /dev/null +++ b/viscy/light/embedding_writer.py @@ -0,0 +1,71 @@ +import logging +from pathlib import Path +from typing import Literal, Sequence + +import pandas as pd +import torch +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import BasePredictionWriter +from xarray import Dataset, open_zarr + +from viscy.data.triplet import INDEX_COLUMNS + +__all__ = ["read_embedding_dataset", "EmbeddingWriter"] +_logger = logging.getLogger("lightning.pytorch") + + +def read_embedding_dataset(path: Path) -> Dataset: + """Read the embedding dataset written by the EmbeddingWriter callback. + + :param Path path: Path to the zarr store. + :return Dataset: Xarray dataset with features and projections. + """ + return open_zarr(path).set_index(sample=INDEX_COLUMNS) + + +class EmbeddingWriter(BasePredictionWriter): + """Callback to write embeddings to a zarr store in an Xarray-compatible format. + + :param Path output_path: Path to the zarr store. + :param Literal["batch", "epoch", "batch_and_epoch"] write_interval: + When to write the embeddings, defaults to "epoch". + """ + + def __init__( + self, + output_path: Path, + write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "epoch", + ): + super().__init__(write_interval) + self.output_path = Path(output_path) + + def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + if self.output_path.exists(): + raise FileExistsError(f"Output path {self.output_path} already exists.") + _logger.debug(f"Writing embeddings to {self.output_path}") + + def write_on_epoch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + predictions: Sequence[dict], + batch_indices: Sequence[int], + ) -> None: + features = torch.cat([p["features"] for p in predictions], dim=0) + projections = torch.cat([p["projections"] for p in predictions], dim=0) + index = pd.MultiIndex.from_frame( + pd.concat([pd.DataFrame(p["index"]) for p in predictions]) + ) + dataset = Dataset( + { + "features": (("sample", "features"), features.cpu().numpy()), + "projections": ( + ("sample", "projections"), + projections.cpu().numpy(), + ), + }, + coords={"sample": index}, + ).reset_index("sample") + _logger.debug(f"Wrtiting predictions dataset:\n{dataset}") + zarr_store = dataset.to_zarr(self.output_path, mode="w") + zarr_store.close() diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 6e658100..4c474c4a 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -3,7 +3,6 @@ from typing import Literal, Sequence, Union import numpy as np -import pandas as pd import torch import torch.nn.functional as F from imageio import imwrite @@ -585,10 +584,6 @@ def __init__( embedding_len: int = 256, predict: bool = False, drop_path_rate: float = 0.2, - tracks_path: str = "data/tracks", - features_output_path: str = "", - projections_output_path: str = "", - metadata_output_path: str = "", ) -> None: super().__init__() self.loss_function = loss_function @@ -605,10 +600,6 @@ def __init__( self.test_metrics = [] self.processed_order = [] self.predictions = [] - self.tracks_path = tracks_path - self.features_output_path = features_output_path - self.projections_output_path = projections_output_path - self.metadata_output_path = metadata_output_path self.model = ContrastiveEncoder( backbone=backbone, in_channels=in_channels, @@ -742,73 +733,13 @@ def configure_optimizers(self): optimizer = Adam(self.parameters(), lr=self.lr) return optimizer - def on_predict_start(self) -> None: - if not ( - self.features_output_path - and self.projections_output_path - and self.metadata_output_path - ): - raise ValueError( - "Output paths for features, projections, and metadata must be provided." - ) - - def predict_step(self, batch: TripletSample, batch_idx, dataloader_idx=0): + def predict_step( + self, batch: TripletSample, batch_idx, dataloader_idx=0 + ) -> dict[str, Tensor | dict]: """Prediction step for extracting embeddings.""" features, projections = self.model(batch["anchor"]) - index = batch["index"] - self.predictions.append( - (features.cpu().numpy(), projections.cpu().numpy(), index) - ) - return features, projections, index - - def on_predict_epoch_end(self) -> None: - combined_features = [] - combined_projections = [] - accumulated_data = [] - - for features, projections, index in self.predictions: - combined_features.extend(features) - combined_projections.extend(projections) - - fov_names = index["fov_name"] - cell_ids = index["id"].cpu().numpy() - - for fov_name, cell_id in zip(fov_names, cell_ids): - parts = fov_name.split("/") - row = parts[1] - column = parts[2] - fov = parts[3] - - csv_path = os.path.join( - self.tracks_path, - row, - column, - fov, - f"tracks_{row}_{column}_{fov}.csv", - ) - - df = pd.read_csv(csv_path) - - track_id = df[df["id"] == cell_id]["track_id"].values[0] - timestep = df[df["id"] == cell_id]["t"].values[0] - - accumulated_data.append((row, column, fov, track_id, timestep)) - - combined_features = np.array(combined_features) - combined_projections = np.array(combined_projections) - - np.save(self.features_output_path, combined_features) - np.save(self.projections_output_path, combined_projections) - - rows, columns, fovs, track_ids, timesteps = zip(*accumulated_data) - df = pd.DataFrame( - { - "Row": rows, - "Column": columns, - "FOV": fovs, - "Cell ID": track_ids, - "Timestep": timesteps, - } - ) - - df.to_csv(self.metadata_output_path, index=False) + return { + "features": features, + "projections": projections, + "index": batch["index"], + } diff --git a/viscy/unet/networks/resnet.py b/viscy/unet/networks/resnet.py deleted file mode 100644 index a34f7271..00000000 --- a/viscy/unet/networks/resnet.py +++ /dev/null @@ -1,30 +0,0 @@ -from torch import Tensor, nn - - -class resnetStem(nn.Module): - """Stem for ResNet networks to handle 3D multi-channel input.""" - - # Currently identical to UNeXt2Stem, but could be different in the future. This module is unused for now. - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: tuple[int, int, int], - in_stack_depth: int, - ) -> None: - super().__init__() - ratio = in_stack_depth // kernel_size[0] - self.conv = nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels // ratio, - kernel_size=kernel_size, - stride=kernel_size, - ) - - def forward(self, x: Tensor): - x = self.conv(x) - b, c, d, h, w = x.shape - # project Z/depth into channels - # return a view when possible (contiguous) - return x.reshape(b, c * d, h, w)