diff --git a/src/pluggy/_callers.py b/src/pluggy/_callers.py index 787f56ba..3872a32c 100644 --- a/src/pluggy/_callers.py +++ b/src/pluggy/_callers.py @@ -17,7 +17,6 @@ from ._result import Result from ._warnings import PluggyTeardownRaisedWarning - # Need to distinguish between old- and new-style hook wrappers. # Wrapping with a tuple is the fastest type-safe way I found to do it. Teardown = Union[ @@ -70,6 +69,11 @@ def _multicall( for hook_impl in reversed(hook_impls): try: args = [caller_kwargs[argname] for argname in hook_impl.argnames] + kwargs = { + k: v + for k, v in caller_kwargs.items() + if k in hook_impl.kwargnames + } except KeyError: for argname in hook_impl.argnames: if argname not in caller_kwargs: @@ -82,7 +86,7 @@ def _multicall( try: # If this cast is not valid, a type error is raised below, # which is the desired response. - res = hook_impl.function(*args) + res = hook_impl.function(*args, **kwargs) wrapper_gen = cast(Generator[None, Result[object], None], res) next(wrapper_gen) # first yield teardowns.append((wrapper_gen, hook_impl)) @@ -92,14 +96,14 @@ def _multicall( try: # If this cast is not valid, a type error is raised below, # which is the desired response. - res = hook_impl.function(*args) + res = hook_impl.function(*args, **kwargs) function_gen = cast(Generator[None, object, object], res) next(function_gen) # first yield teardowns.append(function_gen) except StopIteration: _raise_wrapfail(function_gen, "did not yield") else: - res = hook_impl.function(*args) + res = hook_impl.function(*args, **kwargs) if res is not None: results.append(res) if firstresult: # halt further impl calls diff --git a/testing/test_multicall.py b/testing/test_multicall.py index 7d8d8f28..8607efc4 100644 --- a/testing/test_multicall.py +++ b/testing/test_multicall.py @@ -50,8 +50,15 @@ def test_keyword_args_with_defaultargs() -> None: def f(x, z=1): return x + z - reslist = MC([f], dict(x=23, y=24)) - assert reslist == [24] + @hookimpl + def f2(x, y=1): + return x + y + + reslist = MC([f, f2], dict(x=23, y=24)) + assert reslist == [23 + 24, 23 + 1] + + reslist = MC([f2], dict(x=23)) + assert reslist == [23 + 1] def test_tags_call_error() -> None: