Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Potential solution for #6595 #6598

Draft
wants to merge 7 commits into
base: dev
Choose a base branch
from
Draft
64 changes: 56 additions & 8 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import inspect
import itertools
import random
import warnings
Expand Down Expand Up @@ -1744,22 +1745,69 @@
def attach_hook(func, hook, mode="pre"):
"""
Adds `hook` before or after a `func` call. If mode is "pre", the wrapper will call hook then func.
If the mode is "post", the wrapper will call func then hook.
If the mode is "post", the wrapper will call func then hook. In the case that additional arguments are
passed with the function 'func', the hook function will be called with any of those additional arguments
that the hook also supports. Such additional arguments must have the same name on both functions to be
matched. Unmatched arguments on 'func' are ignored when calling 'hook'.
"""
supported = {"pre", "post"}
if look_up_option(mode, supported) == "pre":
_hook, _func = hook, func
else:
_hook, _func = func, hook
_mode = look_up_option(mode, supported)

Check warning on line 1754 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1754

Added line #L1754 was not covered by tests

def key_in_args(args, k):
return any(k == a for a in args.args)

Check warning on line 1757 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1756-L1757

Added lines #L1756 - L1757 were not covered by tests

def index_of_key(args, k):
return args.args.index(k)

Check warning on line 1760 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1759-L1760

Added lines #L1759 - L1760 were not covered by tests

def param_has_default(args, k):
if args.defaults is None:
return False
return index_of_key(args, k) >= len(args.args) - len(args.defaults)

Check warning on line 1765 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1762-L1765

Added lines #L1762 - L1765 were not covered by tests

def param_default(args, k):
if args.defaults is None:
raise ValueError(f"Parameter {k} has no default")
d_k = len(args.args) - index_of_key(args, k) - 1
if d_k >= len(args.defaults):
raise ValueError(f"Parameter {k} has no default")
return args.defaults[d_k]

Check warning on line 1773 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1767-L1773

Added lines #L1767 - L1773 were not covered by tests

def key_at_index(args, i):
return args.args[i]

Check warning on line 1776 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1775-L1776

Added lines #L1775 - L1776 were not covered by tests

f_args = inspect.getfullargspec(func)
h_args = inspect.getfullargspec(hook)

Check warning on line 1779 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1778-L1779

Added lines #L1778 - L1779 were not covered by tests

@wraps(func)
def wrapper(inst, data):
data = _hook(inst, data)
return _func(inst, data)
def wrapper(inst, data, *args, **kwargs):
h_kwargs = dict()

Check warning on line 1783 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1782-L1783

Added lines #L1782 - L1783 were not covered by tests

# iterate over the positional args that the wrapper was called with, getting their names.
# add any values for parameter names that are also in the hook function's names
for i_a, a in enumerate(args[2:]):
k = key_at_index(f_args, i_a)
if key_in_args(h_args, k):
h_kwargs[k] = a

Check warning on line 1790 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1787-L1790

Added lines #L1787 - L1790 were not covered by tests

# go over parameters in the keyword args, adding any values for parameter names that are also in the hook function's names
for k, v in kwargs.items():
if key_in_args(h_args, k):
h_kwargs[k] = v

Check warning on line 1795 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1793-L1795

Added lines #L1793 - L1795 were not covered by tests

# handle the corner case where there is a parameter without a default on _hook that has a default on _func, but that hasn't
# been set by the caller. In this case, we get the default for that parameter on _func and pass it to _hook
for k in h_args.args:
if param_has_default(h_args, k) is False and k not in h_kwargs and param_has_default(f_args, k) is True:
h_kwargs[k] = param_default(f_args, k)

Check warning on line 1801 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1799-L1801

Added lines #L1799 - L1801 were not covered by tests

if _mode == "pre":
return func(inst, hook(inst, data, **h_kwargs), *args, **kwargs)
return hook(inst, func(inst, data, *args, **kwargs), **h_kwargs)

Check warning on line 1805 in monai/transforms/utils.py

View check run for this annotation

Codecov / codecov/patch

monai/transforms/utils.py#L1803-L1805

Added lines #L1803 - L1805 were not covered by tests

return wrapper



def sync_meta_info(key, data_dict, t: bool = True):
"""
Given the key, sync up between metatensor `data_dict[key]` and meta_dict `data_dict[key_transforms/meta_dict]`.
Expand Down
Loading