The problem
dr.wrap_ad allows for convenient interoperability between tensor frameworks for derivative tracking. However, it is sometimes useful to call such a wrapped function without tracking derivatives, e.g. inside a with dr.suspend_grad(): scope.
An example of such a use case is recomputing the loss value at regular intervals during an optimization, purely for visualization purposes.
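A minimal sketch of this usage pattern (the loss function, tensor shapes, and LLVM backend below are illustrative assumptions, not part of the report):

```python
import drjit as dr
import torch
from drjit.llvm.ad import TensorXf  # or drjit.cuda.ad.TensorXf on the GPU

# Hypothetical PyTorch loss, wrapped so it can be called from DrJIT code.
@dr.wrap_ad(source='drjit', target='torch')
def loss_fn(img, ref):
    # Return a 1-element tensor rather than a 0-dim scalar.
    return ((img - ref) ** 2).mean().reshape(1)

img = dr.zeros(TensorXf, shape=(16, 16))
ref = dr.ones(TensorXf, shape=(16, 16))
dr.enable_grad(img)

# Periodic evaluation for visualization only: no derivatives are needed here.
with dr.suspend_grad():
    print(loss_fn(img, ref))
```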
Source of the issue
The current implementation always sets the tensor's requires_grad to True (drjit/drjit/router.py, line 5827 in ebeed9b):
This causes PyTorch to track derivatives through the computation, even though backward is never called, nor can it be, since gradient computation is disabled on the DrJIT side. As a consequence, an isolated PyTorch computation graph is built and never freed, because _torch.autograd.backward never runs.
This can cause severe memory leaks when the wrapped computation is large, e.g. when it evaluates a deep neural network in PyTorch.
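For intuition, here is a PyTorch-only sketch of the retention mechanism (the shapes and the CUDA device are arbitrary assumptions): as long as a result carrying an autograd graph stays alive and backward is never called, the buffers saved for the backward pass stay allocated.

```python
import torch

x = torch.randn(2048, 2048, device='cuda', requires_grad=True)

before = torch.cuda.memory_allocated()
y = torch.tanh(x @ x).sum()  # builds an autograd graph and saves intermediates
after = torch.cuda.memory_allocated()

# backward() is never called, yet the saved intermediates stay allocated
# for as long as `y` (and therefore its graph) is kept alive.
print(f'{after - before} bytes held by the unused graph')
```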
While this could be fixed relatively easily on the user side by adding a flag to the wrapped function that manually disables gradients, such a workaround is very error prone. It would be more desirable to simply rely on the grad_enabled flag of each DrJIT variable instead of the current logic (drjit/drjit/router.py, line 5755 in ebeed9b).
My understanding is that this does not work because all input parameters are forcibly detached when calling CustomOp.eval, in order to avoid tracking derivatives there (drjit/drjit/router.py, line 5583 in ebeed9b).
One option to alleviate this would be to detach the output of eval instead, though that may lose dangling parts of the AD graph produced by the computation inside eval.
Another option would be to add a function is_grad_suspended to DrJIT, and introduce the following change:
```diff
diff --git a/drjit/router.py b/drjit/router.py
index d4cb079..4f4d502 100644
--- a/drjit/router.py
+++ b/drjit/router.py
@@ -5752,7 +5752,7 @@ def wrap_ad(source: str, target: str):
             return {k: drjit_to_torch(v, enable_grad) for k, v in a.items()}
         elif _dr.is_array_v(a) and _dr.is_tensor_v(a):
             b = a.torch()
-            b.requires_grad = _dr.grad_enabled(a) or (enable_grad and _dr.is_diff_v(a))
+            b.requires_grad = _dr.grad_enabled(a) or (enable_grad and _dr.is_diff_v(a) and not _dr.is_grad_suspended())
             return b
         elif _dr.is_diff_v(a) and a.IsFloat:
             raise TypeError("wrap_ad(): differential input arguments "
```
Reproducing
The following snippet will print the amount of memory leaked at each step:
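A sketch along these lines (the stand-in computation, tensor sizes, and CUDA backend are assumptions; the original snippet may differ):

```python
import drjit as dr
import torch
from drjit.cuda.ad import TensorXf  # assumes a CUDA device

# Stand-in for a large PyTorch computation, e.g. a deep neural network.
@dr.wrap_ad(source='drjit', target='torch')
def heavy_fn(x):
    y = x
    for _ in range(8):
        y = torch.tanh(y @ y)
    return y

x = dr.ones(TensorXf, shape=(2048, 2048))
dr.enable_grad(x)

baseline = torch.cuda.memory_allocated()
for i in range(10):
    with dr.suspend_grad():
        out = heavy_fn(x)
    dr.eval(out)
    del out
    # Memory retained on the PyTorch side keeps growing, because each call
    # leaves behind an autograd graph that is never released.
    print(f'step {i}: {torch.cuda.memory_allocated() - baseline} bytes retained')
```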