Skip to content

Commit

Permalink
Merge pull request #184 from clEsperanto/improve-code-quality
Browse files Browse the repository at this point in the history
Improve code quality
  • Loading branch information
StRigaud authored May 17, 2024
2 parents 054aaa9 + 534e1f1 commit 328e845
Show file tree
Hide file tree
Showing 15 changed files with 1,098 additions and 930 deletions.
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v4.5.0
hooks:
- id: check-toml
- id: check-yaml
Expand All @@ -16,18 +16,18 @@ repos:

# black - code formatting
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 24.4.2
hooks:
- id: black

# # isort - import sorting
# - repo: https://github.com/pycqa/isort
# rev: 5.12.0
# hooks:
# - id: isort
# isort - import sorting
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort

# flake8 - linting
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 7.0.0
hooks:
- id: flake8
6 changes: 2 additions & 4 deletions pyclesperanto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
)
from ._functionalities import execute, imshow, list_operations, native_execute
from ._memory import create, create_like, pull, push

from ._tier1 import *
from ._tier2 import *
from ._tier3 import *
Expand All @@ -21,11 +20,10 @@
from ._tier6 import *
from ._tier7 import *
from ._tier8 import *

from ._interroperability import *

from ._version import CLIC_VERSION as __clic_version__
from ._version import COMMON_ALIAS as __common_alias__
from ._version import VERSION as __version__

from ._interroperability import * # isort:skip

default_initialisation()
16 changes: 11 additions & 5 deletions pyclesperanto/_array.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Union
from typing import Optional, Union

import numpy as np

from . import _operators
from ._core import Device, get_device
from ._pyclesperanto import _Array as Array
from ._utils import _assert_supported_dtype

from . import _operators


def _prepare_array(arr) -> np.ndarray:
"""Converts a given array to a numpy array with C memory layout.
Expand Down Expand Up @@ -37,7 +36,12 @@ def __repr__(self) -> str:
return repr_str[:-1] + f", {extra_info})"


def set(self, array: np.ndarray, origin: tuple = None, region: tuple = None) -> None:
def set(
self,
array: np.ndarray,
origin: Optional[tuple] = None,
region: Optional[tuple] = None,
) -> None:
"""Set the content of the Array to the given numpy array.
Parameters
Expand Down Expand Up @@ -73,7 +77,9 @@ def set(self, array: np.ndarray, origin: tuple = None, region: tuple = None) ->
return self


def get(self, origin: tuple = None, region: tuple = None) -> np.ndarray:
def get(
self, origin: Optional[tuple] = None, region: Optional[tuple] = None
) -> np.ndarray:
"""Get the content of the Array into a numpy array.
Parameters
Expand Down
142 changes: 77 additions & 65 deletions pyclesperanto/_decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from functools import wraps
from typing import Callable, Optional
from typing import Any, Callable, List, Optional

from toolz import curry

Expand All @@ -9,89 +9,101 @@
from ._memory import push


class CallableFunction:
"""A class representing a callable function."""

def __call__(self, *args, **kwargs):
pass


class PluginFunction(CallableFunction):
"""A class representing a plugin function."""

def __init__(
self,
function: Callable[..., Any],
category: Optional[List] = None,
):
self.function = function
self.fullargspec = inspect.getfullargspec(function)
self.category = category

def __call__(self, *args, **kwargs):
return self.function(*args, **kwargs)


@curry
def plugin_function(
function: Callable,
category: Optional[list] = None,
priority: int = 0,
function: PluginFunction,
category: Optional[List] = None,
) -> Callable:
"""Function decorator to ensure correct types and values of all parameters.
"""A decorator for kernels functions.
The given input parameters are either of type OpenCL data/image/buffer (which the GPU
understands) or are converted to this type (see push function). If output
parameters of type Image are not set, an empty image is created and
handed over.
The decorator allocate a device if None is provided, and push the input image to the device
if it is not already there.
This function can be extended to support more functionalities in the future if we need to automatised
more behaviours on the kernels calls.
Parameters
----------
function : callable
The function to be executed on the GPU.
output_creator : callable, optional
A function to create an output cleImage given an input cleImage. By
default, we create output images of the same shape and type as input
images.
device_selector : callable, optional
A function to select a device. By default, we use the current device instance.
category : list of str, optional
A list of category names the function is associated with
priority : int, optional
can be used in lists of multiple operations to differentiate multiple operations that fulfill the same purpose
but better/faster/more general.
function : Callable[..., Any]
The function to be decorated.
category : Optional[List], optional
The category of the function, by default None.
Returns
-------
worker_function : callable
The actual function call that will be executed, magically creating
output arguments of the correct type.
Callable
The decorated function.
"""
function = PluginFunction(function, category)

function.fullargspec = inspect.getfullargspec(function)
function.category = category
function.priority = priority

@wraps(function)
@wraps(function.function)
def worker_function(*args, **kwargs):
sig = inspect.signature(function)
# create mapping from position and keyword arguments to parameters
# will raise a TypeError if the provided arguments do not match the signature
# https://docs.python.org/3/library/inspect.html#inspect.Signature.bind
sig = inspect.signature(function.function)
bound = sig.bind(*args, **kwargs)
# set default values for missing arguments
# https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.apply_defaults
bound.apply_defaults()

args_list = function.fullargspec.args
index = next(
(i for i, element in enumerate(args_list) if "input_image" in element), -1
)
arg_name = args_list[index]

# copy images to GPU, and create output array if necessary
for key, value in bound.arguments.items():
if (
is_image(value)
and key in sig.parameters
and sig.parameters[key].annotation is Image
):
bound.arguments[key] = push(value)
if (
key in sig.parameters
and sig.parameters[key].annotation is Device
and value is None
):
input_image = bound.arguments[arg_name]
bound.arguments[key] = input_image.device

# call the decorated function
result = function(*bound.args, **bound.kwargs)

# # Cast the result as an Array if it is not already
# if not isinstance(result, _Array):
# result = _Array(result)
arg_name = get_input_image_arg_name(function)
process_arguments(function, bound, arg_name)

result = function.function(*bound.args, **bound.kwargs)

return result

# this is necessary to obfuscate pyclesperanto's internal structure
worker_function.__module__ = "pyclesperanto"

return worker_function


def get_input_image_arg_name(function: PluginFunction) -> str:
"""Get the name of the input image argument."""
args_list = function.fullargspec.args
index = next(
(i for i, element in enumerate(args_list) if "input_image" in element), -1
)
if index == -1:
raise NotImplementedError(
f"Wrong usage of decorator for the function {function.function.__name__}"
)
return args_list[index]


def process_arguments(
function: PluginFunction, bound: inspect.BoundArguments, input_image_arg_name: str
):
sig = inspect.signature(function.function)
for key, value in bound.arguments.items():
if (
is_image(value)
and key in sig.parameters
and sig.parameters[key].annotation is Image
):
bound.arguments[key] = push(value)
if (
key in sig.parameters
and sig.parameters[key].annotation is Optional[Device]
and value is None
):
input_image = bound.arguments[input_image_arg_name]
bound.arguments[key] = input_image.device
25 changes: 13 additions & 12 deletions pyclesperanto/_functionalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Union

import numpy as np
from matplotlib.colors import ListedColormap

from ._array import Array, Image
from ._core import Device, get_device
Expand Down Expand Up @@ -161,7 +162,7 @@ def imshow(
color_map: Optional[str] = None,
plot=None,
colorbar: Optional[bool] = False,
colormap: Union[str, None] = None,
colormap: Union[str, ListedColormap, None] = None,
alpha: Optional[float] = None,
continue_drawing: Optional[bool] = False,
):
Expand Down Expand Up @@ -209,13 +210,8 @@ def imshow(
if colormap is None:
colormap = color_map

if colormap is None:
colormap = "Greys_r"

cmap = colormap
if labels:
if not hasattr(imshow, "labels_cmap"):
from matplotlib.colors import ListedColormap
if not hasattr(imshow, "colormap"):
from numpy.random import MT19937, RandomState, SeedSequence

rs = RandomState(MT19937(SeedSequence(3)))
Expand All @@ -226,14 +222,17 @@ def imshow(
lut[2] = [1.0, 0.4980392156862745, 0.054901960784313725]
lut[3] = [0.17254901960784313, 0.6274509803921569, 0.17254901960784313]
lut[4] = [0.8392156862745098, 0.15294117647058825, 0.1568627450980392]
imshow.labels_cmap = ListedColormap(lut)
colormap = ListedColormap(lut)

cmap = imshow.labels_cmap
if min_display_intensity is None:
min_display_intensity = 0
if max_display_intensity is None:
max_display_intensity = 65536

if colormap is None:
colormap = "Greys_r"

cmap = colormap
if plot is None:
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -265,7 +264,8 @@ def imshow(


def operations(
must_have_categories: list = None, must_not_have_categories: list = None
must_have_categories: Optional[list] = None,
must_not_have_categories: Optional[list] = None,
) -> dict:
"""Retrieve a dictionary of operations, which can be filtered by annotated categories.
Expand All @@ -292,11 +292,12 @@ def operations(
import pyclesperanto as cle

# retrieve all operations and cache the result for later reuse
operation_list = []
if not hasattr(operations, "_all") or operations._all is None:
operations._all = getmembers(cle, isfunction)
operation_list = getmembers(cle, isfunction)

# filter operations according to given constraints
for operation_name, operation in operations._all:
for operation_name, operation in operation_list:
keep_it = True
if hasattr(operation, "categories") and operation.categories is not None:
if must_have_categories is not None:
Expand Down
Loading

0 comments on commit 328e845

Please sign in to comment.