Skip to content

Commit

Permalink
matched colormap for labels & segmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
mese79 committed Jan 16, 2024
1 parent 22de864 commit 3d6ee3e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/napari_sam_labeling_tools/_sam_rf_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def check_label_layers(self, event: Event):
# to handle layer's name change by user
layer.events.name.disconnect()
layer.events.name.connect(self.check_label_layers)
layer.colormap = colormaps.create_colormap(10)[0]
if "Segmentation" in layer.name:
self.prediction_layer_combo.addItem(layer.name)
else:
Expand Down Expand Up @@ -536,6 +537,8 @@ def predict_slice(self, rf_model, slice_index, img_h, img_w):
-1, SAM.PATCH_CHANNELS + SAM.EMBEDDING_SIZE
)
predictions = rf_model.predict(features).astype(np.uint8)
# to match the segmentation colormap with the labels
predictions[predictions > 0] += 1
segmentation_image = predictions.reshape(img_h, img_w)

# check for postprocessing
Expand Down
9 changes: 5 additions & 4 deletions src/napari_sam_labeling_tools/utils/colormaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_colormap1():
return cm


def create_colormap(num_colors, bright=True, black_first=True, seed=4567):
def create_colormap(num_colors, bright=True, black_first=True, seed=777):
if num_colors < 10:
num_colors = 10

Expand All @@ -65,13 +65,14 @@ def create_colormap(num_colors, bright=True, black_first=True, seed=4567):
low = [0.0, 0.55, 0.9]
high = 1.0

rng = np.random.default_rng(seed)
hues = np.linspace(start=low[0], stop=high, num=num_colors)
rng.shuffle(hues)
hsv_colors = np.stack([
np.linspace(start=low[0], stop=high, num=num_colors),
hues,
np.linspace(start=low[1], stop=high, num=num_colors),
np.linspace(start=low[2], stop=high, num=num_colors)
], axis=1)
rng = np.random.default_rng(seed)
rng.shuffle(hsv_colors)

rgba_colors = np.zeros((num_colors, 4))
for i in range(num_colors):
Expand Down
5 changes: 4 additions & 1 deletion src/napari_sam_labeling_tools/widgets/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import napari
import numpy as np

from napari_sam_labeling_tools.utils import colormaps


def get_layer(napari_viewer, name, layer_types):
for layer in napari_viewer.layers:
Expand All @@ -21,6 +23,7 @@ def add_labels_layer(napari_viewer: napari.Viewer):
for s, sc in zip(scene_size, scale)
]
empty_labels = np.zeros(shape, dtype=np.uint8)
napari_viewer.add_labels(
layer = napari_viewer.add_labels(
empty_labels, name="Labels", translate=np.array(corner), scale=scale
)
layer.colormap = colormaps.create_colormap(10)[0]

0 comments on commit 3d6ee3e

Please sign in to comment.