Skip to content

Commit

Permalink
ENH: Python dispatch on the first RequiredInputName
Browse files Browse the repository at this point in the history
Dispatch Python filter types based on the type of the first
ProcessObject RequiredInputName, which is exposed in Python via
GetRequiredInputName and which can be the ProcessObject
PrimaryInputName.

This helps to address uses cases such as
`itk.elastix_registration_method`, where you still want to infer the
filter type based on the input fixed image type, but it may be passed in
as a keyword argument, such as
`itk.elastix_registration_method(moving_image=moving_image, fixed_image=fixed_image)`.

Re: InsightSoftwareConsortium#4858 InsightSoftwareConsortium/ITKElastix#212
  • Loading branch information
thewtex authored and hjmjohnson committed Nov 8, 2024
1 parent 73841d0 commit 8185e10
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 13 deletions.
11 changes: 11 additions & 0 deletions Wrapping/Generators/Python/Tests/PythonTemplateTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@
median_kwarg = itk.MedianImageFilter.New(Input=reader.GetOutput())
assert itk.class_(median) == itk.class_(median_kwarg)

# filter type determined by the input passed as a primary input name input
median_primary_kwarg = itk.MedianImageFilter.New(Primary=reader.GetOutput())
assert itk.class_(median) == itk.class_(median_primary_kwarg)

# First RequiredInputName: "FixedImage"
fixed_image = reader.GetOutput()
moving_image = fixed_image
pde_registration = itk.PDEDeformableRegistrationFilter.New(
FixedImage=fixed_image, MovingImage=moving_image
)

# to a filter with a SetImage method
calculator = itk.MinimumMaximumImageCalculator[ImageType].New(reader)
# not GetImage() method here to verify it's the right image
Expand Down
12 changes: 12 additions & 0 deletions Wrapping/Generators/Python/Tests/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ def custom_callback(name, progress):
image = itk.imread(filename, imageio=itk.PNGImageIO.New())
assert type(image) == itk.Image[itk.RGBPixel[itk.UC], 2]

# Python functional interface that determines the filter type
# based on the primary input (dispatch)
image = itk.imread(filename, itk.UC)
# positional argument
filtered_positional = itk.median_image_filter(image)
# required primary named input argument
filtered_kwarg = itk.median_image_filter(primary=image)
comparison = itk.comparison_image_filter(
filtered_positional, filtered_kwarg, verify_input_information=True
)
assert np.sum(comparison) == 0.0

# imread using a dicom series
image = itk.imread(sys.argv[8])
image0 = itk.imread(sys.argv[8], series_uid=0)
Expand Down
16 changes: 3 additions & 13 deletions Wrapping/Generators/Python/itk/support/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .helpers import wasm_type_from_image_type, image_type_from_wasm_type
from .helpers import wasm_type_from_mesh_type, mesh_type_from_wasm_type, python_to_js
from .helpers import wasm_type_from_pointset_type, pointset_type_from_wasm_type
from .helpers import snake_to_camel_case

from .xarray import xarray_from_image, image_from_xarray

Expand Down Expand Up @@ -1552,17 +1553,6 @@ def search(s: str, case_sensitive: bool = False) -> List[str]: # , fuzzy=True):
return res


def _snake_to_camel(keyword: str):
# Helpers for set_inputs snake case to CamelCase keyword argument conversion
_snake_underscore_re = re.compile("(_)([a-z0-9A-Z])")

def _underscore_upper(match_obj):
return match_obj.group(2).upper()

camel = keyword[0].upper()
if _snake_underscore_re.search(keyword[1:]):
return camel + _snake_underscore_re.sub(_underscore_upper, keyword[1:])
return camel + keyword[1:]


def set_inputs(
Expand Down Expand Up @@ -1656,7 +1646,7 @@ def SetInputs(self, *args, **kargs):
# (Ex: itk.ImageFileReader.UC2.New(SetFileName='image.png'))
if attribName not in ["auto_progress", "template_parameters"]:
if attribName.islower():
attribName = _snake_to_camel(attribName)
attribName = snake_to_camel_case(attribName)
attrib = getattr(new_itk_object, "Set" + attribName)

# Do not use try-except mechanism as this leads to
Expand Down Expand Up @@ -2162,7 +2152,7 @@ def ipython_kw_matches(text: str):
namespace = split_name_parts[:-1]
function_name = split_name_parts[-1]
# Find corresponding object name
object_name = _snake_to_camel(function_name)
object_name = snake_to_camel_case(function_name)
# Check that this object actually exists
try:
object_callable_match = ".".join(namespace + [object_name])
Expand Down
12 changes: 12 additions & 0 deletions Wrapping/Generators/Python/itk/support/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@
pass


def snake_to_camel_case(keyword: str):
# Helpers for set_inputs snake case to CamelCase keyword argument conversion
_snake_underscore_re = re.compile("(_)([a-z0-9A-Z])")

def _underscore_upper(match_obj):
return match_obj.group(2).upper()

camel = keyword[0].upper()
if _snake_underscore_re.search(keyword[1:]):
return camel + _snake_underscore_re.sub(_underscore_upper, keyword[1:])
return camel + keyword[1:]

def camel_to_snake_case(name):
snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
snake = re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake)
Expand Down
15 changes: 15 additions & 0 deletions Wrapping/Generators/Python/itk/support/template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from itk.support import base
from itk.support.extras import output
from itk.support.types import itkCType
from itk.support.helpers import snake_to_camel_case
import math
from collections.abc import Mapping

Expand Down Expand Up @@ -718,6 +719,20 @@ def ttype_for_input_type(keys_l, input_type_l):
# try to find a type suitable for the input provided
input_type = output(cur).__class__
keys = ttype_for_input_type(keys, input_type)
else:
inst = self.values()[0].New()
if hasattr(inst, "GetRequiredInputNames"):
required_input_names = inst.GetRequiredInputNames()
if len(required_input_names) > 0:
primary_input_name = required_input_names[0]
kwargs_camel = {snake_to_camel_case(k): v for k,v in kwargs.items()}
if primary_input_name in kwargs_camel.keys():
input_type = output(kwargs_camel[primary_input_name]).__class__
keys = ttype_for_input_type(keys, input_type)
if not hasattr(inst, f"Set{primary_input_name}"):
arg_0 = kwargs_camel.pop(primary_input_name)
kwargs = kwargs_camel
args = (arg_0,) + args

if len(keys) == 0:
if not input_type:
Expand Down

0 comments on commit 8185e10

Please sign in to comment.