Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: New image, v2 #312

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 90 additions & 16 deletions src/napari_imagej/widgets/parameter_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
They should align with a SciJava ModuleItem that satisfies some set of conditions.
"""

from __future__ import annotations

import importlib
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from imagej.images import _imglib2_types
from jpype import JArray, JClass, JInt, JLong
Expand All @@ -24,13 +26,29 @@
request_values,
)
from napari import current_viewer
from napari.layers import Layer
from napari.layers import Layer, Image
from napari.utils._magicgui import get_layers
from numpy import dtype
from scyjava import numeric_bounds

from napari_imagej.java import jc

if TYPE_CHECKING:
from typing import Literal, Sequence


# Generally, Python libraries treat the dimensions i of an array with n
# dimensions as the CONVENTIONAL_DIMS[n][i] axis
# FIXME: Also in widget_utils.py
CONVENTIONAL_DIMS = [
[],
["X"],
["Y", "X"],
["Z", "Y", "X"],
["T", "Y", "X", "C"],
["T", "Z", "Y", "X", "C"],
]


def widget_supported_java_types() -> List[JClass]:
"""
Expand Down Expand Up @@ -147,6 +165,49 @@ def value(self, value: Any):
return Widget


class MutableDimWidget(Container):
def __init__(
self,
val: int = 256,
idx: int = 0,
dim: Literal["X", "Y", "Z", "C", "T", ""] = "",
**kwargs,
):
layer_tooltip = f"Parameters for the dimension of index {idx}"
self.size_spin = SpinBox(value=val, min=1)
choices = ["X", "Y", "Z", "C", "T"]
layer_kwargs = kwargs.copy()
layer_kwargs["nullable"] = True
self.layer_select = ComboBox(
choices=choices, tooltip=layer_tooltip, **layer_kwargs
)
self._nullable = True
kwargs["widgets"] = [self.size_spin, self.layer_select]
kwargs["labels"] = False
kwargs["layout"] = "horizontal"
kwargs.pop("value")
kwargs.pop("nullable")
super().__init__(**kwargs)
self.margins = (0, 0, 0, 0)

@property
def value(self) -> tuple[int, str]:
return (self.size_spin.value, self.layer_select.value)

@value.setter
def value(self, v: tuple[int, str]) -> None:
self.size_spin.value = v[0]
self.layer_select.value = v[1]


class MutableDimsWidget(ListEdit):
def __init__(
self,
**kwargs,
):
super().__init__(**kwargs)


class MutableOutputWidget(Container):
"""
A ComboBox widget combined with a button that creates new layers.
Expand Down Expand Up @@ -205,12 +266,23 @@ def _default_layer(self) -> Optional[Layer]:
selection_name = widget.current_choice
if selection_name != "":
return current_viewer().layers[selection_name]
return None

def _default_new_shape(self):
def _default_new_shape(self) -> Sequence[tuple[int, str]]:
guess = self._default_layer()
if guess:
return guess.data.shape
return [512, 512]
data = guess.data
# xarray has dims, otherwise use conventions
if hasattr(data, "dims"):
dims = data.dims
# Special case: RGB
elif isinstance(guess, Image) and guess.rgb:
dims = list(CONVENTIONAL_DIMS[len(data.shape) - 1])
dims.append("C")
else:
dims = list(CONVENTIONAL_DIMS[len(data.shape)])
return [t for t in zip(data.shape, dims)]
return [(512, "Y"), (512, "X")]

def _default_new_type(self) -> str:
"""
Expand All @@ -237,12 +309,10 @@ def create_new_image(self) -> None:
"""

# Array types that are always included
backing_choices = ["NumPy"]
backing_choices = ["xarray", "NumPy"]
# Array types that may be present
if importlib.util.find_spec("zarr"):
backing_choices.append("Zarr")
if importlib.util.find_spec("xarray"):
backing_choices.append("xarray")

# Define the magicgui widget for parameter harvesting
params = request_values(
Expand All @@ -253,16 +323,19 @@ def create_new_image(self) -> None:
options=dict(tooltip="If blank, a name will be generated"),
),
shape=dict(
annotation=List[int],
annotation=List[MutableDimWidget],
value=self._default_new_shape(),
options=dict(
tooltip="By default, the shape of the first Layer input",
options=dict(min=0, max=2**31 - 10),
layout="vertical",
options=dict(
widget_type=MutableDimWidget,
),
tooltip="The size of each image axis",
),
),
array_type=dict(
annotation=str,
value="NumPy",
value="xarray",
options=dict(
tooltip="The backing data array implementation",
choices=backing_choices,
Expand Down Expand Up @@ -295,7 +368,7 @@ def _add_new_image(self, params: dict):
import numpy as np

data = np.full(
shape=tuple(params["shape"]),
shape=tuple(p[0] for p in params["shape"]),
fill_value=params["fill_value"],
dtype=params["data_type"],
)
Expand All @@ -305,7 +378,7 @@ def _add_new_image(self, params: dict):
import zarr

data = zarr.full(
shape=params["shape"],
shape=tuple(p[0] for p in params["shape"]),
fill_value=params["fill_value"],
dtype=params["data_type"],
)
Expand All @@ -317,10 +390,11 @@ def _add_new_image(self, params: dict):

data = xarray.DataArray(
data=np.full(
shape=tuple(params["shape"]),
shape=tuple(p[0] for p in params["shape"]),
fill_value=params["fill_value"],
dtype=params["data_type"],
)
),
dims=tuple(p[1] for p in params["shape"]),
)

# give the data array to the viewer.
Expand Down
22 changes: 12 additions & 10 deletions src/napari_imagej/widgets/widget_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
info_for,
)

# Generally, Python libraries treat the dimensions i of an array with n
# dimensions as the CONVENTIONAL_DIMS[n][i] axis
CONVENTIONAL_DIMS = [
[],
["X"],
["Y", "X"],
["Z", "Y", "X"],
["T", "Y", "X", "C"],
["T", "Z", "Y", "X", "C"],
]


def python_actions_for(
result: "jc.SearchResult", output_signal: Signal, parent_widget: QWidget = None
Expand Down Expand Up @@ -168,15 +179,6 @@ def __init__(self, title: str, choices: List[Layer], required=True):
class DimsComboBox(QFrame):
"""A QFrame used to map the axes of a Layer to dimension labels"""

dims = [
[],
["X"],
["Y", "X"],
["Z", "Y", "X"],
["T", "Y", "X", "C"],
["T", "Z", "Y", "X", "C"],
]

def __init__(self, combo_box: LayerComboBox):
super().__init__()
self.selection_box: LayerComboBox = combo_box
Expand All @@ -199,7 +201,7 @@ def update(self, index: int):
# Determine the selected layer
selected = self.selection_box.combo.itemData(index)
# Guess dimension labels for the selection
guess = self.dims[len(selected.data.shape)]
guess = CONVENTIONAL_DIMS[len(selected.data.shape)]
# Create dimension selectors for each dimension of the selection.
for i, g in enumerate(guess):
self.layout().addWidget(
Expand Down
24 changes: 17 additions & 7 deletions tests/widgets/test_parameter_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,28 @@ def test_mutable_output_default_parameters(

# Assert when no selection, output shape is default
assert input_widget.current_choice == ""
assert output_widget._default_new_shape() == [512, 512]
assert output_widget._default_new_shape() == [(512, "Y"), (512, "X")]
assert output_widget._default_new_type() == "float64"

# Add new image
shape = (128, 128, 3)
# Add new Z-stack
shape = (3, 128, 128)
import numpy as np

current_viewer().add_image(data=np.ones(shape, dtype=np.int32), name="img")
assert input_widget.current_choice == "img"
assert output_widget._default_new_shape() == shape
current_viewer().add_image(data=np.ones(shape, dtype=np.int32), name="Z")
assert input_widget.current_choice == "Z"
assert output_widget._default_new_shape() == [(3, "Z"), (128, "Y"), (128, "X")]
assert output_widget._default_new_type() == "int32"

# Add new RGB image
shape = (128, 128, 3)
import numpy as np

current_viewer().layers.clear()
current_viewer().add_image(data=np.ones(shape, dtype=np.int16), name="RGB")
assert input_widget.current_choice == "RGB"
assert output_widget._default_new_shape() == [(128, "Y"), (128, "X"), (3, "C")]
assert output_widget._default_new_type() == "int16"


def test_mutable_output_dtype_choices(
input_widget: ComboBox, output_widget: MutableOutputWidget
Expand Down Expand Up @@ -147,7 +157,7 @@ def test_mutable_output_add_new_image(
params = {
"name": "foo",
"array_type": choice,
"shape": (100, 100, 3),
"shape": ((100, "Y"), (100, "X"), (3, "C")),
"fill_value": 3.0,
"data_type": np.int32,
}
Expand Down