diff --git a/pyclesperanto/_array.py b/pyclesperanto/_array.py index f4e47d72..9df57ccd 100644 --- a/pyclesperanto/_array.py +++ b/pyclesperanto/_array.py @@ -7,14 +7,17 @@ import numpy as np import warnings + def __str__(self) -> str: return self.get().__str__() + def __repr__(self) -> str: repr_str = self.get().__repr__() extra_info = f"mtype={self.mtype}" return repr_str[:-1] + f", {extra_info})" + def set(self, array: np.ndarray) -> None: if array.dtype != self.dtype: warnings.warn( @@ -35,6 +38,7 @@ def set(self, array: np.ndarray) -> None: self._write(array) return self + def get(self) -> np.ndarray: caster = { "float32": self._read_float32, @@ -49,11 +53,13 @@ def get(self) -> np.ndarray: } return caster[self.dtype.name]() + def __array__(self, dtype=None) -> np.ndarray: - if dtype is None: - return self.get() - else: - return self.get().astype(dtype) + if dtype is None: + return self.get() + else: + return self.get().astype(dtype) + # missing operators: # __setitem__ @@ -66,32 +72,33 @@ def __array__(self, dtype=None) -> np.ndarray: setattr(Array, "__str__", __str__) setattr(Array, "__repr__", __repr__) setattr(Array, "__array__", __array__) -setattr(Array,"astype",_operators.astype) -setattr(Array,"max",_operators.max) -setattr(Array,"min",_operators.min) -setattr(Array,"sum",_operators.sum) -setattr(Array,"__iadd__",_operators.__iadd__) -setattr(Array,"__sub__",_operators.__sub__) -setattr(Array,"__div__",_operators.__div__) -setattr(Array,"__truediv__",_operators.__truediv__) -setattr(Array,"__idiv__",_operators.__idiv__) -setattr(Array,"__itruediv__",_operators.__itruediv__) -setattr(Array,"__mul__",_operators.__mul__) -setattr(Array,"__imul__",_operators.__imul__) -setattr(Array,"__gt__",_operators.__gt__) -setattr(Array,"__ge__",_operators.__ge__) -setattr(Array,"__lt__",_operators.__lt__) -setattr(Array,"__le__",_operators.__le__) -setattr(Array,"__eq__",_operators.__eq__) -setattr(Array,"__ne__",_operators.__ne__) -setattr(Array,"__pow__",_operators.__pow__) -setattr(Array,"__ipow__",_operators.__ipow__) -setattr(Array,"_plt_to_png",_operators._plt_to_png) -setattr(Array,"_png_to_html",_operators._png_to_html) -setattr(Array,"_repr_html_",_operators._repr_html_) +setattr(Array, "astype", _operators.astype) +setattr(Array, "max", _operators.max) +setattr(Array, "min", _operators.min) +setattr(Array, "sum", _operators.sum) +setattr(Array, "__iadd__", _operators.__iadd__) +setattr(Array, "__sub__", _operators.__sub__) +setattr(Array, "__div__", _operators.__div__) +setattr(Array, "__truediv__", _operators.__truediv__) +setattr(Array, "__idiv__", _operators.__idiv__) +setattr(Array, "__itruediv__", _operators.__itruediv__) +setattr(Array, "__mul__", _operators.__mul__) +setattr(Array, "__imul__", _operators.__imul__) +setattr(Array, "__gt__", _operators.__gt__) +setattr(Array, "__ge__", _operators.__ge__) +setattr(Array, "__lt__", _operators.__lt__) +setattr(Array, "__le__", _operators.__le__) +setattr(Array, "__eq__", _operators.__eq__) +setattr(Array, "__ne__", _operators.__ne__) +setattr(Array, "__pow__", _operators.__pow__) +setattr(Array, "__ipow__", _operators.__ipow__) +setattr(Array, "_plt_to_png", _operators._plt_to_png) +setattr(Array, "_png_to_html", _operators._png_to_html) +setattr(Array, "_repr_html_", _operators._repr_html_) Image = Union[np.ndarray, Array] + def is_image(any_array): return ( isinstance(any_array, np.ndarray) diff --git a/pyclesperanto/_functionalities.py b/pyclesperanto/_functionalities.py index 68c4b829..9a47a014 100644 --- a/pyclesperanto/_functionalities.py +++ b/pyclesperanto/_functionalities.py @@ -66,7 +66,7 @@ def imshow( labels: Optional[bool] = False, min_display_intensity: Optional[float] = None, max_display_intensity: Optional[float] = None, - color_map: Optional[str]=None, + color_map: Optional[str] = None, plot=None, colorbar: Optional[bool] = False, colormap: Union[str, None] = None, @@ -172,7 +172,9 @@ def imshow( plt.title(title) -def operations(must_have_categories : list = None, must_not_have_categories : list = None) -> dict: +def operations( + must_have_categories: list = None, must_not_have_categories: list = None +) -> dict: """Retrieve a dictionary of operations, which can be filtered by annotated categories. Parameters @@ -205,26 +207,38 @@ def operations(must_have_categories : list = None, must_not_have_categories : li keep_it = True if hasattr(operation, "categories") and operation.categories is not None: if must_have_categories is not None: - if not all(item in operation.categories for item in must_have_categories): + if not all( + item in operation.categories for item in must_have_categories + ): keep_it = False if must_not_have_categories is not None: - if any(item in operation.categories for item in must_not_have_categories): + if any( + item in operation.categories for item in must_not_have_categories + ): keep_it = False else: if must_have_categories is not None: keep_it = False - if (keep_it): + if keep_it: result[operation_name] = operation return result -def list_operations(search_term = None): +def list_operations(search_term=None): ops = operations(search_term) for name in ops: func = ops[name] - if hasattr(func, 'fullargspec'): - print(name + "(" + str(func.fullargspec.args).replace('[','').replace(']','').replace('\'','') + ")") + if hasattr(func, "fullargspec"): + print( + name + + "(" + + str(func.fullargspec.args) + .replace("[", "") + .replace("]", "") + .replace("'", "") + + ")" + ) else: - print(name) \ No newline at end of file + print(name) diff --git a/pyclesperanto/_memory.py b/pyclesperanto/_memory.py index 3e8e292b..9f9fdc36 100644 --- a/pyclesperanto/_memory.py +++ b/pyclesperanto/_memory.py @@ -6,13 +6,14 @@ import numpy as np from typing import Tuple, Optional + def create( shape: Tuple[int, ...], dtype: Optional[type] = None, mtype: Optional[str] = None, device: Optional[Device] = None, ) -> Image: - """ Create a new image on the device. + """Create a new image on the device. Parameters ---------- @@ -36,7 +37,10 @@ def create( dtype = np.float32 if dtype in [float, np.float64]: dtype = np.float32 - warnings.warn("Warning: float64 type is not a supported in GPUs. Casting data to float32 type.", UserWarning) + warnings.warn( + "Warning: float64 type is not a supported in GPUs. Casting data to float32 type.", + UserWarning, + ) if mtype is None: mtype = "buffer" if device is None: @@ -50,7 +54,7 @@ def create_like( mtype: Optional[str] = None, device: Optional[Device] = None, ) -> Image: - """ Create a new image on the device with the same shape and dtype as the input image. + """Create a new image on the device with the same shape and dtype as the input image. Parameters ---------- @@ -72,7 +76,10 @@ def create_like( dtype = array.dtype if dtype in [float, np.float64]: dtype = np.float32 - warnings.warn("Warning: float64 type is not a supported in GPUs. Casting data to float32 type.", UserWarning) + warnings.warn( + "Warning: float64 type is not a supported in GPUs. Casting data to float32 type.", + UserWarning, + ) return create(array.shape, dtype, mtype, device) @@ -82,7 +89,7 @@ def push( mtype: Optional[str] = None, device: Optional[Device] = None, ) -> Image: - """ Create a new image on the device and push the input image into it. + """Create a new image on the device and push the input image into it. Parameters ---------- @@ -110,7 +117,7 @@ def push( def pull(array: Image) -> np.ndarray: - """ Pull the input image from the device to the host. + """Pull the input image from the device to the host. Parameters ---------- diff --git a/pyclesperanto/_operators.py b/pyclesperanto/_operators.py index 3916bdc7..ee05e7f9 100644 --- a/pyclesperanto/_operators.py +++ b/pyclesperanto/_operators.py @@ -20,6 +20,7 @@ _supported_numeric_types = tuple(cl_buffer_datatype_dict.keys()) + def astype(self, dtype: type): if dtype not in _supported_numeric_types: raise ValueError( @@ -38,6 +39,7 @@ def astype(self, dtype: type): copy(input_image=self, output_image=result) return result + def max(self, axis: Optional[int] = None, out=None): from ._tier2 import maximum_of_all_pixels from ._tier1 import maximum_x_projection @@ -64,8 +66,8 @@ def max(self, axis: Optional[int] = None, out=None): out = result return result -def min(self, axis: Optional[int] = None, out=None): +def min(self, axis: Optional[int] = None, out=None): from ._tier2 import minimum_of_all_pixels from ._tier1 import minimum_x_projection from ._tier1 import minimum_y_projection @@ -89,6 +91,7 @@ def min(self, axis: Optional[int] = None, out=None): np.copyto(out, pull(result).astype(out.dtype)) return result + def sum(self, axis: Optional[int] = None, out=None): from ._tier2 import sum_of_all_pixels from ._tier1 import sum_x_projection @@ -113,6 +116,7 @@ def sum(self, axis: Optional[int] = None, out=None): np.copyto(out, pull(result).astype(out.dtype)) return result + def __iadd__(x1, x2): from ._tier1 import copy @@ -126,6 +130,7 @@ def __iadd__(x1, x2): return add_images_weighted(temp, x2, x1, factor0=1, factor1=1) + def __sub__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import add_image_and_scalar @@ -136,6 +141,7 @@ def __sub__(x1, x2): return add_images_weighted(x1, x2, factor0=1, factor1=-1) + def __div__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import multiply_image_and_scalar @@ -146,9 +152,11 @@ def __div__(x1, x2): return divide_images(x1, x2) + def __truediv__(x1, x2): return x1.__div__(x2) + def __idiv__(x1, x2): from ._tier1 import copy @@ -162,9 +170,11 @@ def __idiv__(x1, x2): return divide_images(temp, x2, x1) + def __itruediv__(x1, x2): return x1.__idiv__(x2) + def __mul__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import multiply_image_and_scalar @@ -175,6 +185,7 @@ def __mul__(x1, x2): return multiply_images(x1, x2) + def __imul__(x1, x2): from ._tier1 import copy @@ -188,6 +199,7 @@ def __imul__(x1, x2): return multiply_images(temp, x2, x1) + def __gt__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import greater_constant @@ -198,6 +210,7 @@ def __gt__(x1, x2): return greater(x1, x2) + def __ge__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import greater_or_equal_constant @@ -208,6 +221,7 @@ def __ge__(x1, x2): return greater_or_equal(x1, x2) + def __lt__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import smaller_constant @@ -218,6 +232,7 @@ def __lt__(x1, x2): return smaller(x1, x2) + def __le__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import smaller_or_equal_constant @@ -228,6 +243,7 @@ def __le__(x1, x2): return smaller_or_equal(x1, x2) + def __eq__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import equal_constant @@ -238,6 +254,7 @@ def __eq__(x1, x2): return equal(x1, x2) + def __ne__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import not_equal_constant @@ -248,6 +265,7 @@ def __ne__(x1, x2): return not_equal(x1, x2) + def __pow__(x1, x2): if isinstance(x2, _supported_numeric_types): from ._tier1 import power @@ -258,6 +276,7 @@ def __pow__(x1, x2): return power_images(x1, x2) + def __ipow__(x1, x2): from ._tier1 import copy @@ -273,7 +292,7 @@ def __ipow__(x1, x2): def __iter__(self): - class MyIterator(): + class MyIterator: def __init__(self, image): self.image = image self._iter_index = 0 @@ -282,6 +301,7 @@ def __next__(self): import numpy as np from ._memory import create from ._tier1 import copy_slice + if not hasattr(self, "_iter_index"): self._iter_index = 0 if self._iter_index < self.image.shape[0]: @@ -296,12 +316,14 @@ def __next__(self): return result else: raise StopIteration + return MyIterator(self) def __getitem__(self, index): raise NotImplementedError("Not implemented yet.") + def __setitem__(self, index, value): raise NotImplementedError("Not implemented yet.") @@ -317,91 +339,105 @@ def _plt_to_png(self): from io import BytesIO with BytesIO() as file_obj: - plt.savefig(file_obj, format='png') - plt.close() # supress plot output + plt.savefig(file_obj, format="png") + plt.close() # supress plot output file_obj.seek(0) png = file_obj.read() - return png + return png def _png_to_html(self, png): import base64 - url = 'data:image/png;base64,' + base64.b64encode(png).decode('utf-8') - return f'' + + url = "data:image/png;base64," + base64.b64encode(png).decode("utf-8") + return f'' def _repr_html_(self): - """HTML representation of the image object for IPython. - Returns - ------- - HTML text with the image and some properties. - """ - import numpy as np - import matplotlib.pyplot as plt - from ._functionalities import imshow + """HTML representation of the image object for IPython. + Returns + ------- + HTML text with the image and some properties. + """ + import numpy as np + import matplotlib.pyplot as plt + from ._functionalities import imshow - size_in_pixels = np.prod(self.size) - size_in_bytes = size_in_pixels * self.dtype.itemsize + size_in_pixels = np.prod(self.size) + size_in_bytes = size_in_pixels * self.dtype.itemsize - labels = (self.dtype == np.uint32) + labels = self.dtype == np.uint32 - # In case the image is 2D, 3D and larger than 100 pixels, turn on fancy view - if len(self.shape) in (2, 3) and size_in_pixels >= 100: - import matplotlib.pyplot as plt - imshow(self, - labels=labels, - continue_drawing=True, - colorbar=not labels) - image = self._png_to_html(self._plt_to_png()) - else: - return "
cle.array(" + str(np.asarray(self)) + ", dtype=" + str(self.dtype) + ")
" - - units = ['B', 'kB', 'MB', 'GB', 'TB', 'PB'] - unit_index = 0 - while size_in_bytes > 1024 and unit_index < len(units) - 1: - size_in_bytes /= 1024 - unit_index += 1 - size = "{:.1f}".format(size_in_bytes) + " " + units[unit_index] - - histogram = "" - if size_in_bytes < 100 * 1024 * 1024: - if not labels: - - from ._tier3 import histogram - - num_bins = 32 - h = np.asarray(histogram(self, nbins=num_bins, min=self.min(), max=self.max())) - plt.figure(figsize=(1.8, 1.2)) - plt.bar(range(0, len(h)), h) - # hide axis text - # https://stackoverflow.com/questions/2176424/hiding-axis-text-in-matplotlib-plots - # https://pythonguides.com/matplotlib-remove-tick-labels - frame1 = plt.gca() - frame1.axes.xaxis.set_ticklabels([]) - frame1.axes.yaxis.set_ticklabels([]) - plt.tick_params(left=False, bottom=False) - histogram = self._png_to_html(self._plt_to_png()) - min_max = "min" + str(self.min()) + "" + \ - "max" + str(self.max()) + "" - else: - min_max = "" - all = [ - "", - "", - "", - "", - "", - "
", - image, - "", - "cle._ image
", - "", - "", - "", - "", - min_max, - "
shape" + str(self.shape).replace(" ", " ") + "
dtype" + str(self.dtype) + "
size" + size + "
", - histogram, - "
", - ] - return "\n".join(all) + # In case the image is 2D, 3D and larger than 100 pixels, turn on fancy view + if len(self.shape) in (2, 3) and size_in_pixels >= 100: + import matplotlib.pyplot as plt + + imshow(self, labels=labels, continue_drawing=True, colorbar=not labels) + image = self._png_to_html(self._plt_to_png()) + else: + return ( + "
cle.array("
+            + str(np.asarray(self))
+            + ", dtype="
+            + str(self.dtype)
+            + ")
" + ) + + units = ["B", "kB", "MB", "GB", "TB", "PB"] + unit_index = 0 + while size_in_bytes > 1024 and unit_index < len(units) - 1: + size_in_bytes /= 1024 + unit_index += 1 + size = "{:.1f}".format(size_in_bytes) + " " + units[unit_index] + + histogram = "" + if size_in_bytes < 100 * 1024 * 1024: + if not labels: + from ._tier3 import histogram + + num_bins = 32 + h = np.asarray( + histogram(self, nbins=num_bins, min=self.min(), max=self.max()) + ) + plt.figure(figsize=(1.8, 1.2)) + plt.bar(range(0, len(h)), h) + # hide axis text + # https://stackoverflow.com/questions/2176424/hiding-axis-text-in-matplotlib-plots + # https://pythonguides.com/matplotlib-remove-tick-labels + frame1 = plt.gca() + frame1.axes.xaxis.set_ticklabels([]) + frame1.axes.yaxis.set_ticklabels([]) + plt.tick_params(left=False, bottom=False) + histogram = self._png_to_html(self._plt_to_png()) + min_max = ( + "min" + + str(self.min()) + + "" + + "max" + + str(self.max()) + + "" + ) + else: + min_max = "" + all = [ + "", + "", + "", + '", + "", + "
", + image, + "', + 'cle._ image
', + "", + "", + "", + "", + min_max, + "
shape" + + str(self.shape).replace(" ", " ") + + "
dtype" + str(self.dtype) + "
size" + size + "
", + histogram, + "
", + ] + return "\n".join(all)