Skip to content

Commit

Permalink
some cool classifying
Browse files Browse the repository at this point in the history
  • Loading branch information
brisvag committed Aug 4, 2023
1 parent 9385a10 commit 59c4972
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 19 deletions.
7 changes: 4 additions & 3 deletions stemia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@


def _print_tree(ctx, param, value):
print_command_tree(cli)
ctx.exit()
if value:
print_command_tree(cli)
ctx.exit()


@click.group(name='stemia', context_settings=dict(help_option_names=['-h', '--help'], show_default=True))
@click.version_option(version=version)
@click.option('-l', '--list', is_flag=True, is_eager=True, expose_value=False, callback=_print_tree,
help='print all the available commands')
def cli(list):
def cli():
"""
Main entry point for stemia. Several subcommands are available.
Expand Down
63 changes: 49 additions & 14 deletions stemia/image/classify_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,29 @@


@click.command()
@click.argument('classes', nargs=-1, type=click.Path(exists=True, dir_okay=False, resolve_path=True))
def cli(classes):
@click.argument('stacks', nargs=-1, type=click.Path(exists=True, dir_okay=False, resolve_path=True))
@click.option('-c', '--max-classes', default=5, type=int)
def cli(stacks, max_classes):
"""
Do hierarchical classification of particle stacks based on densities.
"""
from pathlib import Path

import mrcfile
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from rich.progress import Progress
from scipy.cluster.hierarchy import dendrogram, linkage
from rich import print
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
import plotly.express as px

from stemia.utils.image_processing import compute_dist_field, create_mask_from_field

if not classes:
if not stacks:
return

with mrcfile.open(classes[0], header_only=True) as mrc:
with mrcfile.open(stacks[0], header_only=True) as mrc:
shape = mrc.header[['nx', 'ny']].item()

radius = min(shape) / 2
Expand All @@ -33,22 +40,50 @@ def cli(classes):

field_squared = dist_field**2

images = {}

print('Running with {max_classes} classes.')

df = pd.DataFrame(columns=['total_density', 'radius_of_gyration'])
df.index.name = 'image'
with Progress() as progress:
for cl in progress.track(classes, description='Reading data...'):
data = mrcfile.read(cl)
for st in progress.track(stacks, description='Reading data...'):
data = mrcfile.read(st)
images[Path(st).stem] = data
for idx, img in enumerate(progress.track(data, description='Calculating features...')):
img_name = f'{Path(cl).stem}_{idx}'
img_name = f'{Path(st).stem}_{idx}'
img -= img.min()
img /= img.mean()
img *= mask
total_density = img.sum()
gyr = np.sqrt(np.sum(field_squared * img))
df.loc[img_name] = [total_density, gyr]
progress.update(progress.task_ids[-1], visible=False)

Z = linkage(df.to_numpy(), 'centroid', optimal_ordering=True)
from matplotlib import pyplot as plt
fig = plt.figure(figsize=(25, 10))
dn = dendrogram(Z)
plt.savefig('classification.png')
df.to_csv('classification.csv', sep='\t')
proc_task = progress.add_task('Classifying...', total=3)

Z = linkage(df.to_numpy(), 'centroid', optimal_ordering=True)
progress.update(proc_task, advance=1)
classes = fcluster(Z, t=max_classes, criterion='maxclust')
progress.update(proc_task, advance=1)
df['class'] = classes

fig = plt.figure(figsize=(50, 20))
_ = dendrogram(Z)
plt.savefig('classification.png')
df.to_csv('classification.csv', sep='\t')
df['name'] = df.index

fig = px.scatter(df, x='total_density', y='radius_of_gyration', color='class', hover_name='name')
fig.show()
progress.update(proc_task, advance=1)

for cl, df_cl in progress.track(df.groupby('class'), description='Splitting classes...'):
stacked = []
for img in df_cl.index:
*img_name, idx = img.split('_')
img_name = '_'.join(img_name)
idx = int(idx)
stacked.append(images[img_name][idx])
mrc = mrcfile.new(f'{img_name}_class_{cl:04}.mrc', np.stack(stacked), overwrite=True)
mrc.close()
4 changes: 2 additions & 2 deletions stemia/utils/click.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def print_command_tree(cli, prefix='', last=True):
last_sub = len(subs) - i == 1
end_prefix = '└──' if last_sub else '├──'
print(
f'[white]{prefix + end_prefix}[/][bold]{sub.name}[/]: '
f'[white]{prefix + end_prefix} [/][bold]{sub.name}[/]: '
f'[italic white]{(sub.__doc__ or "-").strip().splitlines()[0]}[/]'
)
sub_prefix = ' ' if last_sub else '│ '
sub_prefix = ' ' if last_sub else '│ '
print_command_tree(sub, prefix=prefix + sub_prefix, last=last_sub)

0 comments on commit 59c4972

Please sign in to comment.