Config-based prediction with Xarray-based output format (#132)
* use callback to write prediction embeddings

* moving over the script to compute infection score from contrastive_update

* delete unused stem module

* organize scripts and CLIs for contrastive phenotyping

* add dependencies for prediction

* export embedding dataset reader function

* add more plots to script

* use real paths in predict config

* do not require seaborn and umap-learn for base install

* use relative path in example job script

* add docstrings for embedding writer and reader

* don't assign unused grid object

* show time and id as hover data in interactive plot

* fix typo

* fix script to test data i/o

* ignore accidental lightning_logs

* add plotly and nbformat to visual dependencies

* tweak predict cli example

* add another plot type - raw features of random samples

* comment on speed of clustermap

* add prediction config example to specify log path

* simplify env var in job script
and match cpu count with config

* vectorize string concatenation

---------

Co-authored-by: Shalin Mehta <[email protected]>
2 people authored and edyoshikun committed Aug 15, 2024
1 parent 9678911 commit eb83d8c
Showing 19 changed files with 590 additions and 227 deletions.
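This commit moves prediction to a config-driven Lightning CLI workflow: an EmbeddingWriter callback writes embeddings to a Zarr store, and the exported read_embedding_dataset function loads them back as an Xarray dataset. A minimal usage sketch (the output path here is a placeholder; the coordinate names follow the evaluation script in this commit):

from viscy.light.embedding_writer import read_embedding_dataset

# Placeholder path: point this at a Zarr store written by the EmbeddingWriter callback.
dataset = read_embedding_dataset("predictions/embeddings.zarr")
# Embeddings are stored in the "features" variable, indexed along "sample"
# by fov_name, track_id, t, and id.
features = dataset["features"]
print(features.sizes)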
5 changes: 4 additions & 1 deletion .gitignore
@@ -40,4 +40,7 @@ htmlcov/
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

#lightning_logs directory
lightning_logs/
@@ -0,0 +1,145 @@
# %%
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from umap import UMAP

from viscy.light.embedding_writer import read_embedding_dataset

# %%
dataset = read_embedding_dataset(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04-tokenized-drop_path_0_0.zarr"
)
dataset

# %%
# load all unprojected features:
features = dataset["features"]
# or select a well:
# features = features[features["fov_name"].str.contains("B/4")]
features

# %%
# examine raw features
random_samples = np.random.randint(0, dataset.sizes["sample"], 700)
# concatenate fov_name, track_id, and t to create a unique sample identifier
sample_id = (
    features["fov_name"][random_samples]
    + "-"
    + features["track_id"][random_samples].astype(str)
    + "-"
    + features["t"][random_samples].astype(str)
)
px.imshow(
    features.values[random_samples],
    labels={
        "x": "feature",
        "y": "sample",
        "color": "value",
    },  # change labels to match our metadata
    y=sample_id,
    # show fov_name as y-axis
)

# %%
scaled_features = StandardScaler().fit_transform(features.values)

umap = UMAP()

embedding = umap.fit_transform(scaled_features)
features = (
    features.assign_coords(UMAP1=("sample", embedding[:, 0]))
    .assign_coords(UMAP2=("sample", embedding[:, 1]))
    .set_index(sample=["UMAP1", "UMAP2"], append=True)
)
features

# %%
sns.scatterplot(
    x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
)


# %%
def load_annotation(da, path, name, categories: dict | None = None):
    annotation = pd.read_csv(path)
    annotation["fov_name"] = "/" + annotation["fov ID"]
    annotation = annotation.set_index(["fov_name", "id"])
    mi = pd.MultiIndex.from_arrays(
        [da["fov_name"].values, da["id"].values], names=["fov_name", "id"]
    )
    selected = annotation.loc[mi][name]
    if categories:
        selected = selected.astype("category").cat.rename_categories(categories)
    return selected


# %%
ann_root = Path(
    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track"
)

infection = load_annotation(
    features,
    ann_root / "tracking_v1_infection.csv",
    "infection class",
    {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
)
division = load_annotation(
    features,
    ann_root / "cell_division_state.csv",
    "division",
    {0: "non-dividing", 2: "dividing"},
)


# %%
sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=division, s=7, alpha=0.8)

# %%
sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8)

# %%
ax = sns.histplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, bins=64)
sns.move_legend(ax, loc="lower left")

# %%
sns.displot(
    x=features["UMAP1"],
    y=features["UMAP2"],
    kind="hist",
    col=infection,
    bins=64,
    cmap="inferno",
)

# %%
# interactive scatter plot to associate clusters with specific cells

px.scatter(
    data_frame=pd.DataFrame(
        {k: v for k, v in features.coords.items() if k != "features"}
    ),
    x="UMAP1",
    y="UMAP2",
    color=(infection.astype(str) + " " + division.astype(str)).rename("annotation"),
    hover_name="fov_name",
    hover_data=["id", "t"],
)

# %%
# cluster features in heatmap directly
# this is very slow for large datasets even with fastcluster installed
inf_codes = pd.Series(infection.values.codes, name="infection")
lut = dict(zip(inf_codes.unique(), "brw"))
row_colors = inf_codes.map(lut)

g = sns.clustermap(
    scaled_features, row_colors=row_colors.to_numpy(), col_cluster=False, cbar_pos=None
)
# hide per-sample tick labels on the heatmap axes
g.ax_heatmap.yaxis.set_ticks([])
# %%
50 changes: 50 additions & 0 deletions applications/contrastive_phenotyping/contrastive_cli/predict.yml
@@ -0,0 +1,50 @@
seed_everything: 42
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 32-true
  callbacks:
    - class_path: viscy.light.embedding_writer.EmbeddingWriter
      init_args:
        output_path: "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/test_prediction_code.zarr"
    # edit the following lines to specify logging path
    # - class_path: lightning.pytorch.loggers.TensorBoardLogger
    #   init_args:
    #     save_dir: /path/to/save_dir
    #     version: name-of-experiment
    #     log_graph: True
  inference_mode: true
model:
  backbone: convnext_tiny
  in_channels: 2
  in_stack_depth: 15
  stem_kernel_size: [5, 4, 4]
data:
  data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr
  tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr
  source_channel:
    - Phase3D
    - RFP
  z_range: [28, 43]
  batch_size: 32
  num_workers: 16
  initial_yx_patch_size: [192, 192]
  final_yx_patch_size: [192, 192]
  normalizations:
    - class_path: viscy.transforms.NormalizeSampled
      init_args:
        keys: [Phase3D]
        level: fov_statistics
        subtrahend: mean
        divisor: std
    - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
      init_args:
        keys: [RFP]
        lower: 50
        upper: 99
        b_min: 0.0
        b_max: 1.0
return_predictions: false
ckpt_path: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs/tokenized-drop-path-0.0/checkpoints/epoch=96-step=23377.ckpt
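For intuition, the RFP normalization above rescales each sample's 50th to 99th intensity percentiles to [0, 1]. A rough NumPy sketch of that percentile scaling (an illustrative approximation of the MONAI-style transform, not the implementation used by this config):

import numpy as np

def scale_percentiles(img, lower=50, upper=99, b_min=0.0, b_max=1.0):
    # Map the [p_lower, p_upper] intensity range of this image to [b_min, b_max].
    a_min, a_max = np.percentile(img, [lower, upper])
    return (img - a_min) / (a_max - a_min) * (b_max - b_min) + b_min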
@@ -0,0 +1,21 @@
#!/bin/bash

#SBATCH --job-name=contrastive_predict
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --partition=gpu
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=7G
#SBATCH --time=0-01:00:00

module load anaconda/2022.05
# Update to use the actual prefix
conda activate $MYDATA/envs/viscy

scontrol show job $SLURM_JOB_ID

# use absolute path in production
config=./predict.yml
cat $config
srun python -m viscy.cli.contrastive_triplet predict -c $config
@@ -0,0 +1,166 @@
from argparse import ArgumentParser
from pathlib import Path
import numpy as np
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from viscy.data.triplet import TripletDataModule, TripletDataset
import pandas as pd
import warnings

warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).",
)

# %% Paths and constants
save_dir = (
    "/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4"
)

# rechunked data
data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.2-register_annotations/updated_all_annotations.zarr"

# updated tracking data
tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

source_channel = ["background_mask", "uninfected_mask", "infected_mask"]
z_range = (0, 1)
batch_size = 1 # match the number of fovs being processed such that no data is left
# set to 15 for full, 12 for infected, and 8 for uninfected

# non-rechunked data
data_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

# updated tracking data
tracks_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

source_channel_1 = ["Nuclei_prediction_labels"]


# %% Define the main function for training
def main(hparams):
    # Initialize the data module for prediction, re-do embeddings but with size 224 by 224
    data_module = TripletDataModule(
        data_path=data_path,
        tracks_path=tracks_path,
        source_channel=source_channel,
        z_range=z_range,
        initial_yx_patch_size=(224, 224),
        final_yx_patch_size=(224, 224),
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    data_module.setup(stage="predict")

    print(f"Total prediction dataset size: {len(data_module.predict_dataset)}")

    dataloader = DataLoader(
        data_module.predict_dataset,
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    # Initialize the second data module for segmentation masks
    seg_data_module = TripletDataModule(
        data_path=data_path_1,
        tracks_path=tracks_path_1,
        source_channel=source_channel_1,
        z_range=z_range,
        initial_yx_patch_size=(224, 224),
        final_yx_patch_size=(224, 224),
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    seg_data_module.setup(stage="predict")

    seg_dataloader = DataLoader(
        seg_data_module.predict_dataset,
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    # Initialize lists to store average values
    background_avg = []
    uninfected_avg = []
    infected_avg = []

    for batch, seg_batch in tqdm(
        zip(dataloader, seg_dataloader),
        desc="Processing batches",
        total=len(data_module.predict_dataset),
    ):
        anchor = batch["anchor"]
        seg_anchor = seg_batch["anchor"].int()

        # Extract the fov_name and id from the batch
        fov_name = batch["index"]["fov_name"][0]
        cell_id = batch["index"]["id"].item()

        fov_dirs = fov_name.split("/")
        # Construct the path to the CSV file
        csv_path = os.path.join(
            tracks_path, *fov_dirs, f"tracks{fov_name.replace('/', '_')}.csv"
        )

        # Read the CSV file
        df = pd.read_csv(csv_path)

        # Find the row with the specified id and extract the track_id
        track_id = df.loc[df["id"] == cell_id, "track_id"].values[0]

        # Create a boolean mask where segmentation values are equal to the track_id
        mask = seg_anchor == track_id
        # mask = (seg_anchor > 0)

        # Find the most frequent non-zero value in seg_anchor
        # unique, counts = np.unique(seg_anchor[seg_anchor > 0], return_counts=True)
        # most_frequent_value = unique[np.argmax(counts)]

        # # Create a boolean mask where segmentation values are equal to the most frequent value
        # mask = (seg_anchor == most_frequent_value)

        # Expand the mask to match the anchor tensor shape
        mask = mask.expand(1, 3, 1, 224, 224)

        # Calculate average values for each channel (background, uninfected, infected) using the mask
        background_avg.append(anchor[:, 0, :, :, :][mask[:, 0]].mean().item())
        uninfected_avg.append(anchor[:, 1, :, :, :][mask[:, 1]].mean().item())
        infected_avg.append(anchor[:, 2, :, :, :][mask[:, 2]].mean().item())

    # Convert lists to numpy arrays
    background_avg = np.array(background_avg)
    uninfected_avg = np.array(uninfected_avg)
    infected_avg = np.array(infected_avg)

    print("Average values per cell for each mask calculated.")
    print("Background average shape:", background_avg.shape)
    print("Uninfected average shape:", uninfected_avg.shape)
    print("Infected average shape:", infected_avg.shape)

    # Save the averages as .npy files
    np.save(os.path.join(save_dir, "background_avg.npy"), background_avg)
    np.save(os.path.join(save_dir, "uninfected_avg.npy"), uninfected_avg)
    np.save(os.path.join(save_dir, "infected_avg.npy"), infected_avg)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--backbone", type=str, default="resnet50")
    parser.add_argument("--margin", type=float, default=0.5)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--schedule", type=str, default="Constant")
    parser.add_argument("--log_steps_per_epoch", type=int, default=10)
    parser.add_argument("--embedding_len", type=int, default=256)
    parser.add_argument("--max_epochs", type=int, default=100)
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--devices", type=int, default=1)
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--log_every_n_steps", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=8)
    args = parser.parse_args()
    main(args)
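Downstream of this script, the saved per-cell averages could be loaded back and compared across mask channels as an infection readout. A hypothetical follow-up sketch (the score definition here is an assumption, not part of this commit):

import os
import numpy as np

save_dir = "/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4"
uninfected_avg = np.load(os.path.join(save_dir, "uninfected_avg.npy"))
infected_avg = np.load(os.path.join(save_dir, "infected_avg.npy"))

# Assumed score: per-cell difference between infected and uninfected mask averages.
infection_score = infected_avg - uninfected_avg
print(infection_score.shape)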