Skip to content

Commit

Permalink
Allow forward pass without gradient
Browse files Browse the repository at this point in the history
- previously, a crash would occur if an input that did not require
gradient was passed to a module with a registered rule
- this commit explicitly allows the forward pass of a module with a
registered rule without requiring gradient computation
  • Loading branch information
chr5tphr committed Jul 2, 2021
1 parent 1b5cfae commit c984127
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions zennit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,15 @@ def wrapper(grad_input, grad_output):
if not isinstance(input, tuple):
input = (input,)

post_input = Identity.apply(*input)
post_input[0].grad_fn.register_hook(wrapper)
# work around to support in-place operations
post_input = tuple(elem.clone() for elem in post_input)
if input[0].grad_fn is not None:
# only if gradient required
post_input = Identity.apply(*input)
post_input[0].grad_fn.register_hook(wrapper)
# work around to support in-place operations
post_input = tuple(elem.clone() for elem in post_input)
else:
# no gradient required
post_input = input
return post_input[0] if len(post_input) == 1 else post_input

def post_forward(self, module, input, output):
Expand All @@ -160,7 +165,9 @@ def wrapper(grad_input, grad_output):
if not isinstance(output, tuple):
output = (output,)

output[0].grad_fn.register_hook(wrapper)
if output[0].grad_fn is not None:
# only if gradient required
output[0].grad_fn.register_hook(wrapper)
return output[0] if len(output) == 1 else output

def pre_backward(self, module, grad_input, grad_output):
Expand Down

0 comments on commit c984127

Please sign in to comment.