Skip to content

Commit

Permalink
Merge pull request #12 from JesseMckinzie/feature_heatmap
Browse files Browse the repository at this point in the history
Add feature calculation heatmap
  • Loading branch information
sameeul authored Jun 5, 2024
2 parents 1105b4e + 002ce2f commit 349322a
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 16 deletions.
269 changes: 253 additions & 16 deletions napari_nyxus/nyx_napari.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
from qtpy.QtWidgets import QWidget, QScrollArea, QTableWidget, QVBoxLayout,QTableWidgetItem, QLineEdit, QLabel, QHBoxLayout, QAbstractItemView
from qtpy.QtWidgets import QWidget, QCheckBox, QScrollArea, QTableWidget, QVBoxLayout,QTableWidgetItem, QLineEdit, QLabel, QHBoxLayout, QPushButton, QComboBox
from qtpy.QtCore import Qt, QTimer
from qtpy import QtCore, QtGui, QtWidgets
from qtpy.QtGui import QColor
from superqt import QLabeledDoubleRangeSlider
import napari
from napari.layers import Image, Labels
from napari.qt.threading import thread_worker
from napari.utils.notifications import show_info
from magicgui import magic_factory

from enum import Enum
import numpy as np
import pandas as pd
import dask
from filepattern import FilePattern
import tempfile

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from napari_nyxus import util
from napari_nyxus.util import util
from napari_nyxus.util import rotated_header

#from napari_nyxus.table import TableWidget, add_table, get_table
from napari_skimage_regionprops import TableWidget, add_table, get_table
Expand Down Expand Up @@ -85,6 +93,8 @@ def __init__(

self.labels_added = False

self.num_annotations = 0

# Check for CUDA enable GPU if requested
if (use_CUDA_Enabled_GPU):
import subprocess
Expand Down Expand Up @@ -160,7 +170,7 @@ def run(self):
show_info("Calculating features...")
self._run_calculate()
self.add_features_table()

self.is_heatmap_added = False


def _run_calculate(self):
Expand All @@ -175,27 +185,56 @@ def _run_calculate(self):

#@thread_worker
def _calculate(self):
""" Calculates the features using Nyxus
""" Calculates the features using Nyxus
"""
if (type(self.intensity.data) == dask.array.core.Array):
self.batched = True
self._calculate_out_of_core()
else:
self.result = self.nyxus_object.featurize(self.intensity.data, self.segmentation.data)
self.result = self.nyxus_object.featurize(self.intensity.data, self.segmentation.data, intensity_names=[self.intensity.name], label_names=[self.segmentation.name])


def _calculate_out_of_core(self):
""" Out of core calculations for when dataset size is larger than what Napari
loads into memory
"""
results = []

from os import walk

# Get files from directory and skip hidden files
filenames = [f for f in next(walk(self.intensity.source.path), (None, None, []))[2] if not f.startswith('.')]
filenames.sort() # sort files to be in same order they appear in napari


self.result = None
names_index = 0

for idx in np.ndindex(self.intensity.data.numblocks):
results.append(self.nyxus_object.featurize(
self.intensity.data.blocks[idx].compute(),
self.segmentation.data.blocks[idx].compute()))

num_files = len(self.intensity.data.blocks[idx])
names = filenames[names_index:names_index + num_files]

# set DataFrame value after first batch is processed
if self.result is None:
self.result = self.nyxus_object.featurize(
self.intensity.data.blocks[idx].compute(),
self.segmentation.data.blocks[idx].compute(),
intensity_names=names,
label_names=names
)

else: # Concat to first batch results
self.result = pd.concat([self.result,
self.nyxus_object.featurize(
self.intensity.data.blocks[idx].compute(),
self.segmentation.data.blocks[idx].compute(),
intensity_names=names,
label_names=names)],
ignore_index=True)

names_index += num_files


self.result = pd.concat(results)
#self.result = pd.concat(results, ignore_index=True)


def add_features_table(self):
Expand All @@ -206,6 +245,71 @@ def add_features_table(self):

self._add_features_table()

def add_feature_calculation_table_options(self):

# create window for feature calculation table functionality
win = FeaturesWidget()
scroll = QScrollArea()
layout = QVBoxLayout()
widget_table = QWidget()
widget_table.setLayout(layout)

# add combobox for selecting heatmap type
self.heatmap_combobox = QComboBox()
self.heatmap_combobox.addItems(plt.colormaps()[:5])
self.heatmap_combobox.addItem('gray')
widget_table.layout().addWidget(self.heatmap_combobox)

self.remove_number_checkbox = QCheckBox("Hide feature calculation values")
widget_table.layout().addWidget(self.remove_number_checkbox)

# button to create heatmap
heatmap_button = QPushButton("Generate heatmap")
heatmap_button.clicked.connect(self.generate_heatmap)

widget_table.layout().addWidget(heatmap_button)

# add text box for selecting column for extracting annotations
self.column_box = QComboBox()
self.column_box.addItems(['intensity_image', 'mask_image'])

# text box for filepattern for extracting annotations
self.filepattern_box = QLineEdit()
self.filepattern_box.setPlaceholderText("Enter filepattern (ex: r{r:d+}_c{c:d+}.tif)")
self.filepattern_box.textChanged.connect(self.check_annotations_input)

# add text box for selecting which filepattern variable to extract annotation
self.annotation_box = QLineEdit()
self.annotation_box.setPlaceholderText("Enter annotation to extract (ex:r)")
self.annotation_box.textChanged.connect(self.check_annotations_input)

self.annotation_button = QPushButton("Extract annotation")
self.annotation_button.clicked.connect(self.extract_annotation)

# add text for selecting column(s) to sort by
self.sort_by_box = QLineEdit()
self.sort_by_box.setPlaceholderText("Enter column to sort rows by")
self.sort_by_box.textChanged.connect(self.check_sort_input)

# add button for sorting columns
self.sort_button = QPushButton("Sort")
self.sort_button.clicked.connect(self._sort)

# add widgets for feature calculations table
widget_table.layout().addWidget(self.column_box)
widget_table.layout().addWidget(self.filepattern_box)
widget_table.layout().addWidget(self.annotation_box)
widget_table.layout().addWidget(self.annotation_button)
widget_table.layout().addWidget(self.sort_by_box)
widget_table.layout().addWidget(self.sort_button)

scroll.setWidget(widget_table)
win.setLayout(layout)
win.setWindowTitle("Feature Table Options")


self.viewer.window.add_dock_widget(win)


def _add_features_table(self):
""" Adds table to Napari viewer
Expand All @@ -225,18 +329,151 @@ def _add_features_table(self):
add_table(labels_layer, self.viewer)

widget_table = get_table(labels_layer, self.viewer)

self.table = widget_table._view

self.table.setHorizontalHeader(rotated_header.RotatedHeaderView(self.table))

self.table.cellDoubleClicked.connect(self.cell_was_clicked)

# remove label clicking event to use our own
widget_table._layer.mouse_drag_callbacks.remove(widget_table._clicked_labels)

try:
widget_table._layer.mouse_drag_callbacks.remove(widget_table._clicked_labels)
except:
print('No mouse drag event to remove')

# add new label clicking event
self.table.horizontalHeader().sectionDoubleClicked.connect(self.onHeaderClicked)

self.add_feature_calculation_table_options()

def check_sort_input(self):

if self.sort_by_box.text():
self.sort_button.setEnabled(True)
else:
self.sort_button.setEnabled(False)

def check_annotations_input(self):

if self.filepattern_box.text() and self.annotation_box.text():
self.annotation_button.setEnabled(True)
else:
self.annotation_button.setEnabled(False)

def _sort(self):

sort_columns = self.sort_by_box.text().split()

# check if columns are valid
for column in sort_columns:
if column not in self.result.columns.to_list():
show_info(f'Column name \"{column}\" is not a valid column.')
return

if (len(sort_columns) == 1): # sort table in place if only one sorting column is passed
sort_column_index = self.result.columns.get_loc(sort_columns[0])

self.table.sortItems(sort_column_index)

else: # sort datafame when multiple columns are passed

#self.result.sort_values(by=sort_columns, ascending=[i % 2 == 0 for i in range(len(sort_columns))], inplace=True)
self.result.sort_values(by=sort_columns, inplace=True)

for row in range(self.result.shape[0]):
for col in range(self.result.shape[1]):
self.table.setItem(row, col, QTableWidgetItem(str(self.result.iat[row, col])))

if self.is_heatmap_added:
self.generate_heatmap()


def generate_heatmap(self):

remove = self.remove_number_checkbox.isChecked()

if remove:
row_height = self.table.rowHeight(1)
column_width = self.table.columnWidth(1)

width = min(row_height, column_width)

self.table.horizontalHeader().setDefaultSectionSize(width)
self.table.verticalHeader().setDefaultSectionSize(width)


for col in range(3 + self.num_annotations, self.result.shape[1]):

# Get the column data
column_data = self.result.iloc[:, col]

# Normalize feature calculation values between 0 and 1 for this column
normalized_values = (column_data - column_data.min()) / (column_data.max() - column_data.min())

# Map normalized values to colors using a specified colormap
colormap = plt.get_cmap(self.heatmap_combobox.currentText())
colors = (colormap(normalized_values) * 255).astype(int) # Multiply by 255 to convert to QColor range
# Set background color for each item in the column

for row in range(self.result.shape[0]):

if remove:
self.table.item(row, col).setText('') # remove feature calculation value from cell

self.table.item(row, col).setBackground(QColor(colors[row][0], colors[row][1], colors[row][2], colors[row][3]))

self.is_heatmap_added = True

def extract_annotation(self, event):
import os

column_name = self.column_box.currentText()
file_pattern = self.filepattern_box.text()
annotation = self.annotation_box.text()

# write filenames to txt file to feed into filepattern (todo: update filepattern to remove need for text file)
# use temp directory to allow filepattern (another process) to open temp file on Windows
with tempfile.TemporaryDirectory() as td:
f_name = os.path.join(td, 'rows')
try:
row_values = self.result[column_name].to_list()
except:
show_info("Invalid column name")
return
with open(f_name, 'w') as fh:
for row in row_values:
fh.write(f"{row}\n")

fp = FilePattern(f_name, file_pattern)

found_annotations = []
for annotations, file in fp:
if annotation not in annotations:
continue
found_annotations.append(annotations[annotation])

if (len(found_annotations) != len(row_values)):
show_info('Error extracting annotations. Check that the filenames match the filepattern.')
return

annotations_position = 3

try:
self.result.insert(annotations_position, annotation, found_annotations, allow_duplicates=False)
except:
show_info("Error inserting annotations column. Check that the name of the annotations column is unique.")
return

self.table.insertColumn(annotations_position)

self.table.setHorizontalHeaderItem(annotations_position, QTableWidgetItem(annotation))

for row in range(self.table.rowCount()):
self.table.setItem(row, annotations_position, QTableWidgetItem(str(found_annotations[row])))

self.num_annotations += 1

def cell_was_clicked(self, event):

if (self.batched):
Expand All @@ -263,12 +500,12 @@ def highlight_value(self, value):

for ix, iy in np.ndindex(self.seg.shape):

if (int(self.seg[ix, iy]) == int(value)):
if (int(self.seg[ix, iy]) == int(float(value))):

if (self.labels[ix, iy] != 0):
self.labels[ix, iy] = 0
else:
self.labels[ix, iy] = int(value)
self.labels[ix, iy] = int(float(value))

if (not removed):
self.current_label += 1
Expand Down
28 changes: 28 additions & 0 deletions napari_nyxus/util/rotated_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from qtpy.QtCore import *
from qtpy.QtGui import *
from qtpy.QtWidgets import QHeaderView

class RotatedHeaderView(QHeaderView):
def __init__(self, parent=None):
super(RotatedHeaderView, self).__init__(Qt.Horizontal, parent)
self.setMinimumSectionSize(20)

def paintSection(self, painter, rect, logicalIndex ):
painter.save()
# translate the painter such that rotate will rotate around the correct point
painter.translate(rect.x()+rect.width(), rect.y())
painter.rotate(90)
# and have parent code paint at this location
newrect = QRect(0,0,rect.height(),rect.width())
super(RotatedHeaderView, self).paintSection(painter, newrect, logicalIndex)
painter.restore()

def minimumSizeHint(self):
size = super(RotatedHeaderView, self).minimumSizeHint()
size.transpose()
return size

def sectionSizeFromContents(self, logicalIndex):
size = super(RotatedHeaderView, self).sectionSizeFromContents(logicalIndex)
size.transpose()
return size
File renamed without changes.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ install_requires =
qtpy
superqt
napari-skimage-regionprops>=0.10.1
matplotlib
filepattern>=2.0.0

[options.entry_points]
napari.manifest =
Expand Down

0 comments on commit 349322a

Please sign in to comment.