diff --git a/zennit/core.py b/zennit/core.py index a4b4475..e1eeb00 100644 --- a/zennit/core.py +++ b/zennit/core.py @@ -143,10 +143,15 @@ def wrapper(grad_input, grad_output): if not isinstance(input, tuple): input = (input,) - post_input = Identity.apply(*input) - post_input[0].grad_fn.register_hook(wrapper) - # work around to support in-place operations - post_input = tuple(elem.clone() for elem in post_input) + if input[0].grad_fn is not None: + # only if gradient required + post_input = Identity.apply(*input) + post_input[0].grad_fn.register_hook(wrapper) + # work around to support in-place operations + post_input = tuple(elem.clone() for elem in post_input) + else: + # no gradient required + post_input = input return post_input[0] if len(post_input) == 1 else post_input def post_forward(self, module, input, output): @@ -160,7 +165,9 @@ def wrapper(grad_input, grad_output): if not isinstance(output, tuple): output = (output,) - output[0].grad_fn.register_hook(wrapper) + if output[0].grad_fn is not None: + # only if gradient required + output[0].grad_fn.register_hook(wrapper) return output[0] if len(output) == 1 else output def pre_backward(self, module, grad_input, grad_output):