Commit 5faa31a: project profile

brisvag committed Jan 31, 2024 (1 parent: 307059f)
Showing 2 changed files with 332 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -74,12 +74,18 @@ flip_z = [
"starfile",
"eulerangles",
]
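# project_profile needs napari for interactive chunk selection, plotly for the
# profile/thickness plots, and kaleido for plotly's static image export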
project_profile = [
"kaleido",
"napari",
"plotly",
]
image = [
"stemia[center_filament]",
"stemia[classify_densities]",
"stemia[create_mask]",
"stemia[extract_z_snapshots]",
"stemia[flip_z]",
"stemia[project_profile]",
]
align_filament_particles = [
"starfile",
326 changes: 326 additions & 0 deletions src/stemia/image/project_profiles.py
@@ -0,0 +1,326 @@
#!/usr/bin/env python3

import click


@click.group()
def cli():
"""Project re-extracted and straightened membranes and get some stats."""
pass
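
# Typical workflow (illustrative invocation; the exact console entry point depends
# on how stemia exposes this module):
#   python project_profiles.py prepare -o proj_dir volume1.mrc volume2.mrc
#   python project_profiles.py compute proj_dir
#   python project_profiles.py aggregate proj_dir/volume1 proj_dir/volume2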


@cli.command()
@click.argument(
"paths", nargs=-1, type=click.Path(exists=True, dir_okay=False, resolve_path=True)
)
@click.option(
"-o", "--output", required=True, type=click.Path(dir_okay=True, resolve_path=True)
)
@click.option("-s", "--chunk-size", type=int, default=35)
@click.option("-f", "--overwrite", is_flag=True)
def prepare(paths, output, chunk_size, overwrite):
"""Generate and select 2D chunked projections for the input data."""
import warnings
from collections import defaultdict
from pathlib import Path

import mrcfile
import napari
import numpy as np
from PIL import Image

def normalize(arr):
arr = arr - np.nanmin(arr)
return arr / np.nanmax(arr)
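
    # helpers for the napari grid view: map a double-click to the layer it landed on
    # (via the layer's private grid translation) and toggle that layer's visibility;
    # chunks hidden by the user are excluded from saving further down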

def get_correct_entry(viewer, event):
for lay in reversed(viewer.layers):
shift = lay._translate_grid[0]
pos = np.array(event.position)[0]
if np.all(pos >= shift):
return lay.name
return None

def open_entry(viewer, event):
if event.modifiers:
return
entry = get_correct_entry(viewer, event)
if entry is not None:
viewer.layers[entry].visible = not viewer.layers[entry].visible

outdir = Path(output)
outdir.mkdir(parents=True, exist_ok=True)

proj_z = defaultdict(dict)
px_sizes = {}
for volume_path in paths:
volume_path = Path(volume_path)
name = volume_path.stem
subdir = outdir / name
subdir.mkdir(parents=True, exist_ok=True)
with mrcfile.open(volume_path) as mrc:
px_size = mrc.voxel_size.x.item()
px_sizes[name] = px_size
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "mean of empty slice")
full = normalize(np.nanmean(mrc.data, axis=0))

chunks = np.split(full, range(chunk_size, full.shape[0], chunk_size), axis=0)
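        # the split above cuts the full projection into chunk_size-row pieces
        # along the first image axis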
# fuse last two chunks if the last one is too small
if len(chunks) > 1 and chunks[-1].shape[0] < chunk_size / 2:
chunks = chunks[:-2] + [np.concatenate(chunks[-2:], axis=0)]
for idx, chunk in enumerate(chunks):
proj_z[volume_path.stem][idx] = chunk

# save full projections
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "Data array contains NaN values")
with mrcfile.new(
subdir / f"{name}_projZ.mrc", full, overwrite=overwrite
) as mrc:
mrc.voxel_size = px_size
img = Image.fromarray((np.nan_to_num(full) * 255).astype(np.uint8))
        # PIL's Image.save has no overwrite option; it always overwrites
        img.save(subdir / f"{name}_projZ.png")

# open napari to view projections and save them
for name, proj in proj_z.items():
print(f"Opening {name}...")
v = napari.Viewer()
for idx, chunk in proj.items():
v.add_image(
chunk,
name=f"{name}_{idx:02}",
contrast_limits=(0, 1),
metadata={"idx": idx},
)

v.mouse_double_click_callbacks.append(open_entry)

v.grid.enabled = True
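        # single-column layout, one layer per grid cell; the negative stride
        # reverses the layer order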
v.grid.stride = -1
v.grid.shape = (-1, 1)
napari.run()
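        # napari.run() blocks until the viewer is closed; only the chunks the user
        # left visible are kept and saved below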

to_save = [lay for lay in v.layers if lay.visible]

print(f"Saving {name}...")
subdir = outdir / name
for lay in to_save:
idx = lay.metadata["idx"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "Data array contains NaN values")
with mrcfile.new(
subdir / f"{name}_projZ_{idx:02}.mrc", lay.data, overwrite=overwrite
) as mrc:
mrc.voxel_size = px_sizes[name]
img = Image.fromarray((np.nan_to_num(lay.data) * 255).astype(np.uint8))
            # as above, PIL always overwrites
            img.save(subdir / f"{name}_projZ_{idx:02}.png")


@cli.command()
@click.argument(
"proj_dir", type=click.Path(exists=True, dir_okay=True, resolve_path=True)
)
@click.option("-f", "--overwrite", is_flag=True)
def compute(proj_dir, overwrite):
"""Take the outputs from prepare and compute statistics and plots."""
import re
import warnings
from collections import defaultdict
from inspect import cleandoc
from pathlib import Path

import mrcfile
import numpy as np
import pandas as pd
import plotly.express as px
from scipy.signal import find_peaks

def normalize(arr):
arr = arr - np.nanmin(arr)
return arr / np.nanmax(arr)

projs = defaultdict(dict)
projs_avg = defaultdict(dict)
px_sizes = {}

proj_dir = Path(proj_dir)
for subdir in proj_dir.iterdir():
if not subdir.is_dir():
print(f"Ignoring {subdir}: not a directory.")
continue
name = subdir.stem
for proj in subdir.glob("*projZ_??.mrc"):
idx = int(re.search(r"projZ_(\d\d)", proj.stem).group(1))
with mrcfile.open(proj) as mrc:
px_sizes[name] = mrc.voxel_size.x.item()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "mean of empty slice")
proj = normalize(np.nanmean(mrc.data, axis=0))
projs[name][idx] = proj
projs_avg[name][idx] = (
pd.DataFrame(proj)
.rolling(window=5, min_periods=1)
.mean()
.to_numpy()
.squeeze()
)

# find peaks
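    # peak positions are sorted by ascending peak height, so p[-2:] below always
    # refers to the two strongest peaks of a profile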
high_peaks = {}
low_peaks = {}
for name, ps in projs_avg.items():
high_peaks_idx, high_peaks_props = zip(
*[find_peaks(p, height=0.2) for p in ps.values()]
)
high_peaks[name] = [
loc[prop["peak_heights"].argsort()]
for loc, prop in zip(high_peaks_idx, high_peaks_props)
]
low_peaks_idx, low_peaks_props = zip(
*[find_peaks(1 - p, height=0.2) for p in ps.values()]
)
low_peaks[name] = [
loc[prop["peak_heights"].argsort()]
for loc, prop in zip(low_peaks_idx, low_peaks_props)
]

# align main peaks
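    # roll each chunk so that the leftmost of its two strongest density minima
    # (membranes, see below) lines up with the same feature in the first chunk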
for name in low_peaks:
        # fallback when fewer than two minima were found: len(p) // 2 is 0 for 0 or 1
        # peaks, so such chunks are effectively assigned a main minimum at index 0
first_2_peaks = np.array(
[
p[-2:] if len(p) >= 2 else (len(p) // 2, len(p) // 2)
for p in low_peaks[name]
]
)
main_low_peaks = first_2_peaks.min(axis=1)

shifts = main_low_peaks[0] - main_low_peaks
for (k, v), shift in zip(projs[name].items(), shifts):
projs[name][k] = np.roll(v, shift)
for (k, v), shift in zip(projs_avg[name].items(), shifts):
projs_avg[name][k] = np.roll(v, shift)

# plot aligned plots
for name in low_peaks:
df = pd.DataFrame(projs[name])
df.index = np.arange(len(df)) * px_sizes[name]
df.sort_index(inplace=True, axis=1)
df.to_csv(proj_dir / name / "profiles.csv")

fig = px.line(df, title=f"Density profile of {name} by chunks")
fig.update_layout(xaxis_title="Position (Å)", yaxis_title="Normalised density")
fig.show()
fig.write_image(proj_dir / name / "profiles.png", width=1400, height=700)

# also average
fig = px.line(df.mean(axis=1), title=f"Mean density profile of {name}")
fig.update_layout(xaxis_title="Position (Å)", yaxis_title="Normalised density")
fig.show()
fig.write_image(proj_dir / name / "profile_average.png", width=1400, height=700)
# save individual plots
for col in df:
fig = px.line(
df[col], title=f"Density profile of {name}, chunk {int(col):02}"
)
fig.update_layout(
xaxis_title="Position (Å)", yaxis_title="Normalised density"
)
fig.write_image(
proj_dir / name / f"profile_{int(col):02}.png", width=1400, height=700
)

# calculate thicknesses
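    # thickness = separation between the two strongest peak positions, converted
    # to Å via the pixel size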
for name in low_peaks:
# low peaks (membranes)
first_2_peaks = np.array(
[p[-2:] if len(p) >= 2 else (0, 0) for p in low_peaks[name]]
)
thickness_dark = (
np.abs(np.diff(first_2_peaks, axis=1)).squeeze() * px_sizes[name]
)
# high peaks (white band)
first_2_peaks = np.array(
[p[-2:] if len(p) >= 2 else (0, 0) for p in high_peaks[name]]
)
thickness_white = (
np.abs(np.diff(first_2_peaks, axis=1)).squeeze() * px_sizes[name]
)

df = pd.DataFrame(
{"thickness_dark_A": thickness_dark, "thickness_white_A": thickness_white},
index=projs[name].keys(),
)
df.index.name = "chunk"
df.sort_index(inplace=True)

fig = px.violin(df, title=f"Thickness distribution of {name}")
fig.update_layout(yaxis_title="Thickness (Å)")
fig.show()
fig.write_image(proj_dir / name / "thickness.png", width=700, height=700)

df.to_csv(proj_dir / name / "thickness.csv")
print(f"--- {name} ---")
print(df)
print(
f"average thickness (dark): {thickness_dark.mean():.2f}, std={thickness_dark.std():.2f}"
)
print(
f"average thickness (white): {thickness_white.mean():.2f}, std={thickness_white.std():.2f}"
)

print(
cleandoc(
"""
NOTE: Some thicknesses are computed incorrectly due to inconsistent peak heights.
Look through the data and manually fix the ones that look obviously wrong
before continuing to the next step, or they will mess up the results.
"""
)
)


@cli.command()
@click.argument(
"inputs", nargs=-1, type=click.Path(exists=True, dir_okay=True, resolve_path=True)
)
def aggregate(inputs):
"""Aggregate the generated data into general stats about given subsets.
Inputs are subdirectories of the project_dir from compute.
"""
from pathlib import Path

import pandas as pd
import plotly.express as px

dfs = []
names = []

for subdir in inputs:
subdir = Path(subdir)
if not subdir.is_dir():
print(f"Ignoring {subdir}: not a directory.")
continue
names.append(subdir.stem)
dfs.append(pd.read_csv(subdir / "thickness.csv", index_col=0))

df = pd.concat(dfs)
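    # aggregated outputs are written next to the input subdirectories (i.e. in the
    # parent of the last input path)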
fig = px.violin(
df, title=f'Aggregated thickness distribution for {", ".join(names)}.'
)
fig.update_layout(yaxis_title="Thickness (Å)")
fig.show()
fig.write_image(
subdir.parent / f'thickness_aggregated_{"_".join(names)}.png',
width=700,
height=700,
)
summary = df.describe()
print(summary)
summary.to_csv(subdir.parent / f'thickness_aggregated_{"_".join(names)}.csv')


if __name__ == "__main__":
cli()
