diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 11608503c7..3b9127c715 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -11,6 +11,7 @@ from __future__ import annotations +import inspect import itertools import random import warnings @@ -1744,22 +1745,69 @@ def scale_affine(spatial_size, new_spatial_size, centered: bool = True): def attach_hook(func, hook, mode="pre"): """ Adds `hook` before or after a `func` call. If mode is "pre", the wrapper will call hook then func. - If the mode is "post", the wrapper will call func then hook. + If the mode is "post", the wrapper will call func then hook. In the case that additional arguments are + passed with the function 'func', the hook function will be called with any of those additional arguments + that the hook also supports. Such additional arguments must have the same name on both functions to be + matched. Unmatched arguments on 'func' are ignored when calling 'hook'. """ supported = {"pre", "post"} - if look_up_option(mode, supported) == "pre": - _hook, _func = hook, func - else: - _hook, _func = func, hook + _mode = look_up_option(mode, supported) + + def key_in_args(args, k): + return any(k == a for a in args.args) + + def index_of_key(args, k): + return args.args.index(k) + + def param_has_default(args, k): + if args.defaults is None: + return False + return index_of_key(args, k) >= len(args.args) - len(args.defaults) + + def param_default(args, k): + if args.defaults is None: + raise ValueError(f"Parameter {k} has no default") + d_k = len(args.args) - index_of_key(args, k) - 1 + if d_k >= len(args.defaults): + raise ValueError(f"Parameter {k} has no default") + return args.defaults[d_k] + + def key_at_index(args, i): + return args.args[i] + + f_args = inspect.getfullargspec(func) + h_args = inspect.getfullargspec(hook) @wraps(func) - def wrapper(inst, data): - data = _hook(inst, data) - return _func(inst, data) + def wrapper(inst, data, *args, **kwargs): + h_kwargs = dict() + + # iterate over the positional args that the wrapper was called with, getting their names. + # add any values for parameter names that are also in the hook function's names + for i_a, a in enumerate(args[2:]): + k = key_at_index(f_args, i_a) + if key_in_args(h_args, k): + h_kwargs[k] = a + + # go over parameters in the keyword args, adding any values for parameter names that are also in the hook function's names + for k, v in kwargs.items(): + if key_in_args(h_args, k): + h_kwargs[k] = v + + # handle the corner case where there is a parameter without a default on _hook that has a default on _func, but that hasn't + # been set by the caller. In this case, we get the default for that parameter on _func and pass it to _hook + for k in h_args.args: + if param_has_default(h_args, k) is False and k not in h_kwargs and param_has_default(f_args, k) is True: + h_kwargs[k] = param_default(f_args, k) + + if _mode == "pre": + return func(inst, hook(inst, data, **h_kwargs), *args, **kwargs) + return hook(inst, func(inst, data, *args, **kwargs), **h_kwargs) return wrapper + def sync_meta_info(key, data_dict, t: bool = True): """ Given the key, sync up between metatensor `data_dict[key]` and meta_dict `data_dict[key_transforms/meta_dict]`.