diff --git a/pluggy.py b/pluggy.py index fa30fcdb..9864c98d 100644 --- a/pluggy.py +++ b/pluggy.py @@ -274,10 +274,12 @@ def __init__(self, project_name, implprefix=None): self.trace = _TagTracer().get("pluginmanage") self.hook = _HookRelay(self.trace.root.get("hook")) self._implprefix = implprefix - self._inner_hookexec = lambda hook, methods, kwargs: \ - _MultiCall( - methods, kwargs, specopts=hook.spec_opts, hook=hook - ).execute() + self._inner_hookexec = lambda hook, methods, kwargs: _MultiCall( + methods, + kwargs, + firstresult=hook.spec.opts['firstresult'] if hook.spec else False, + hook=hook + ).execute() def _hookexec(self, hook, methods, kwargs): # called from all hookcaller instances. @@ -424,7 +426,7 @@ def _verify_hook(self, hook, hookimpl): (hookimpl.plugin_name, hook.name)) # positional arg checking - notinspec = set(hookimpl.argnames) - set(hook.argnames) + notinspec = set(hookimpl.argnames) - set(hook.spec.argnames) if notinspec: raise PluginValidationError( "Plugin %r for hook %r\nhookimpl definition: %s\n" @@ -517,8 +519,8 @@ def subset_hook_caller(self, name, remove_plugins): orig = getattr(self.hook, name) plugins_to_remove = [plug for plug in remove_plugins if hasattr(plug, name)] if plugins_to_remove: - hc = _HookCaller(orig.name, orig._hookexec, orig._specmodule_or_class, - orig.spec_opts) + hc = _HookCaller(orig.name, orig._hookexec, orig.spec.namespace, + orig.spec.opts) for hookimpl in (orig._wrappers + orig._nonwrappers): plugin = hookimpl.plugin if plugin not in plugins_to_remove: @@ -539,17 +541,18 @@ class _MultiCall(object): # so we can remove it soon, allowing to avoid the below recursion # in execute() and simplify/speed up the execute loop. - def __init__(self, hook_impls, kwargs, specopts={}, hook=None): - self.hook = hook + def __init__(self, hook_impls, kwargs, firstresult=False, hook=None): self.hook_impls = hook_impls self.caller_kwargs = kwargs # come from _HookCaller.__call__() self.caller_kwargs["__multicall__"] = self - self.specopts = hook.spec_opts if hook else specopts + self.firstresult = firstresult + self.hook = hook + self.spec = hook.spec if hook else None def execute(self): caller_kwargs = self.caller_kwargs self.results = results = [] - firstresult = self.specopts.get("firstresult") + firstresult = self.firstresult while self.hook_impls: hook_impl = self.hook_impls.pop() @@ -560,8 +563,10 @@ def execute(self): if argname not in caller_kwargs: raise HookCallError( "hook call must provide argument %r" % (argname,)) + if hook_impl.hookwrapper: return _wrapped_call(hook_impl.function(*args), self.execute) + res = hook_impl.function(*args) if res is not None: if firstresult: @@ -645,28 +650,23 @@ def __init__(self, name, hook_execute, specmodule_or_class=None, spec_opts=None) self._wrappers = [] self._nonwrappers = [] self._hookexec = hook_execute - self.argnames = None - self.kwargnames = None + self.spec = None + self._call_history = None if specmodule_or_class is not None: assert spec_opts is not None self.set_specification(specmodule_or_class, spec_opts) def has_spec(self): - return hasattr(self, "_specmodule_or_class") + return self.spec is not None def set_specification(self, specmodule_or_class, spec_opts): assert not self.has_spec() - self._specmodule_or_class = specmodule_or_class - specfunc = getattr(specmodule_or_class, self.name) - # get spec arg signature - argnames, self.kwargnames = varnames(specfunc) - self.argnames = ["__multicall__"] + list(argnames) - self.spec_opts = spec_opts + self.spec = HookSpec(specmodule_or_class, self.name, spec_opts) if spec_opts.get("historic"): self._call_history = [] def is_historic(self): - return hasattr(self, "_call_history") + return self._call_history is not None def _remove_plugin(self, plugin): def remove(wrappers): @@ -702,8 +702,8 @@ def __repr__(self): def __call__(self, **kwargs): assert not self.is_historic() - if self.argnames: - notincall = set(self.argnames) - set(['__multicall__']) - set( + if self.spec: + notincall = set(self.spec.argnames) - set(['__multicall__']) - set( kwargs.keys()) if notincall: warnings.warn( @@ -749,6 +749,16 @@ def _maybe_apply_history(self, method): proc(res[0]) +class HookSpec(object): + def __init__(self, namespace, name, opts): + self.namespace = namespace + self.function = function = getattr(namespace, name) + self.name = name + self.argnames, self.kwargnames = varnames(function) + self.opts = opts + self.argnames = ["__multicall__"] + list(self.argnames) + + class HookImpl(object): def __init__(self, plugin, plugin_name, function, hook_impl_opts): self.function = function diff --git a/testing/benchmark.py b/testing/benchmark.py index e5888a5f..d2d84afc 100644 --- a/testing/benchmark.py +++ b/testing/benchmark.py @@ -13,7 +13,7 @@ def MC(methods, kwargs, firstresult=False): for method in methods: f = HookImpl(None, "", method, method.example_impl) hookfuncs.append(f) - return _MultiCall(hookfuncs, kwargs, {"firstresult": firstresult}) + return _MultiCall(hookfuncs, kwargs, firstresult=firstresult) @hookimpl diff --git a/testing/test_method_ordering.py b/testing/test_method_ordering.py index 7c87702a..8b2b321a 100644 --- a/testing/test_method_ordering.py +++ b/testing/test_method_ordering.py @@ -163,9 +163,9 @@ def he_myhook3(arg1): pass pm.add_hookspecs(HookSpec) - assert not pm.hook.he_myhook1.spec_opts["firstresult"] - assert pm.hook.he_myhook2.spec_opts["firstresult"] - assert not pm.hook.he_myhook3.spec_opts["firstresult"] + assert not pm.hook.he_myhook1.spec.opts["firstresult"] + assert pm.hook.he_myhook2.spec.opts["firstresult"] + assert not pm.hook.he_myhook3.spec.opts["firstresult"] @pytest.mark.parametrize('name', ["hookwrapper", "optionalhook", "tryfirst", "trylast"]) diff --git a/testing/test_multicall.py b/testing/test_multicall.py index 7aff8c33..26aeee6a 100644 --- a/testing/test_multicall.py +++ b/testing/test_multicall.py @@ -22,7 +22,7 @@ def MC(methods, kwargs, firstresult=False): for method in methods: f = HookImpl(None, "", method, method.example_impl) hookfuncs.append(f) - return _MultiCall(hookfuncs, kwargs, {"firstresult": firstresult}) + return _MultiCall(hookfuncs, kwargs, firstresult=firstresult) def test_call_passing():