wrap_ad causes memory leaks when disabling gradient tracking #114

Open
bathal1 opened this issue Dec 15, 2022 · 1 comment
Labels
bug (Something isn't working)

Comments

bathal1 (Contributor) commented Dec 15, 2022

The problem

dr.wrap_ad enables convenient interoperability between tensor frameworks for derivative tracking. However, it is sometimes useful to call such a wrapped function without tracking derivatives, e.g. inside a with dr.suspend_grad(): scope.
An example of such a use case could be recomputing the loss value at regular intervals during an optimization for visualization purposes.

Source of the issue

The current implementation always sets the tensor's requires_grad to True:

self.args_torch = drjit_to_torch(args, enable_grad=True)

This causes PyTorch to track derivatives through the computation, but backward is never called, nor can it be, since gradient computation is disabled on the DrJIT side. As a consequence, each call builds an isolated PyTorch computation graph that is retained and never freed, because _torch.autograd.backward never runs.

This can cause severe memory leaks when the PyTorch side performs large computations, e.g. evaluating a deep neural network.
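
To see the mechanism in isolation, here is a minimal, self-contained PyTorch sketch (independent of wrap_ad, using a hypothetical toy model) of the same pattern: an input with requires_grad=True flows through a network, the outputs are kept alive, and backward is never called, so every iteration retains its own autograd graph together with the intermediate activations it references:

import torch

# Hypothetical toy model standing in for the wrapped PyTorch computation.
net = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda()

for p in net.parameters():
    p.requires_grad = False

outputs = []
for i in range(10):
    x = torch.ones(1024, 1024, device="cuda")
    x.requires_grad = True      # what wrap_ad currently forces on every input
    y = net(x).mean()           # builds an autograd graph through the network
    outputs.append(y)           # keeping y alive keeps the whole graph alive
    # backward() is never called, so the graph is never consumed and freed
    print(torch.cuda.memory_allocated(device=0), "bytes allocated")

If x.requires_grad stays False, or y is detached before being stored, the allocated memory stays flat across iterations.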

While this could be worked around relatively easily on the user side by passing a flag to the wrapped function to manually disable gradients, that approach is very error-prone. It would be preferable to simply rely on the grad_enabled state of each DrJIT variable instead of the current:

b.requires_grad = _dr.grad_enabled(a) or (enable_grad and _dr.is_diff_v(a))

My understanding is that this is due to all input parameters being forcibly detached when calling CustomOp.eval, in order to avoid tracking derivatives there:

output = inst.eval(*_dr.detach(kwargs['args']))

One option to alleviate this would be to detach the output of eval, though that may cause dangling parts of the AD graph built inside eval to be lost.
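
In PyTorch terms, that first option boils down to keeping only a detached copy of the result, as in this illustrative sketch (eval_torch is a placeholder for the wrapped computation, not the actual wrap_ad internals):

import torch

def eval_torch(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the PyTorch computation performed inside CustomOp.eval.
    return (x * 2.0).sum()

x = torch.ones(4, requires_grad=True)

# Keep only a detached copy of eval's output: once the original result goes out
# of scope, nothing references the autograd graph any more and it can be freed.
out = eval_torch(x).detach()

# The downside: the AD edges are severed as well, so derivatives of the
# torch-side computation can no longer be propagated when they are needed.
print(out.requires_grad, out.grad_fn)   # False None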

Another option would be to add a function is_grad_suspended to DrJIT, and introduce the following change:

diff --git a/drjit/router.py b/drjit/router.py
index d4cb079..4f4d502 100644
--- a/drjit/router.py
+++ b/drjit/router.py
@@ -5752,7 +5752,7 @@ def wrap_ad(source: str, target: str):
                     return {k: drjit_to_torch(v, enable_grad) for k, v in a.items()}
                 elif _dr.is_array_v(a) and _dr.is_tensor_v(a):
                     b = a.torch()
-                    b.requires_grad = _dr.grad_enabled(a) or (enable_grad and _dr.is_diff_v(a))
+                    b.requires_grad = _dr.grad_enabled(a) or (enable_grad and _dr.is_diff_v(a) and not _dr.is_grad_suspended())
                     return b
                 elif _dr.is_diff_v(a) and a.IsFloat:
                     raise TypeError("wrap_ad(): differential input arguments "
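
is_grad_suspended does not exist in DrJIT today; the following purely hypothetical Python-level sketch only illustrates the intended semantics (a flag that is True while inside a dr.suspend_grad() scope). An actual implementation would presumably query the AD scope state in the C++ core instead:

import contextlib
import threading

import drjit as dr

_state = threading.local()

def is_grad_suspended() -> bool:
    # Hypothetical helper: True while execution is inside a suspend-grad scope.
    return getattr(_state, "suspended", False)

@contextlib.contextmanager
def suspend_grad():
    # Wrapper around dr.suspend_grad() that records the suspension state so
    # that code like drjit_to_torch() could consult it via is_grad_suspended().
    previous = is_grad_suspended()
    _state.suspended = True
    try:
        with dr.suspend_grad():
            yield
    finally:
        _state.suspended = previous

With such a flag available, the patched line above would only enable requires_grad on the PyTorch side when gradients are actually being tracked, so no computation graph would be built inside a suspend_grad() scope.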

Reproducing

The following snippet will print the amount of memory leaked at each step:

import drjit as dr
import torch

cnn = torch.nn.Sequential(
        torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        torch.nn.ReLU(inplace=True),
        torch.nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        torch.nn.ReLU(inplace=True)
    ).cuda()

for p in cnn.parameters():
    p.requires_grad = False


@dr.wrap_ad(source="drjit", target="torch")
def loss_func(img):
    # Fix: manually disable gradients
    # img.requires_grad = False
    features = cnn(img.T[None, ...])
    x = torch.mean(features)
    return x.cpu()

img = dr.ones(dr.cuda.ad.TensorXf, (256, 256, 3))

losses = []
mem = 0
for i in range(1000):
    tmp = torch.cuda.memory_allocated(device=0)
    print(f"{tmp - mem} bytes leaked")
    mem = tmp
    with dr.suspend_grad():
        loss_clean = loss_func(img)
    losses.append(loss_clean)
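
Until this is addressed in wrap_ad itself, the user-side workaround mentioned above (and in the commented-out line of the snippet) is to disable gradient tracking manually inside the wrapped function. With the same setup as in the snippet, the following variant keeps memory flat across iterations:

@dr.wrap_ad(source="drjit", target="torch")
def loss_func_no_grad(img):
    # Workaround: explicitly disable gradient tracking on the PyTorch side so
    # that no autograd graph is built and retained across calls.
    img.requires_grad = False
    features = cnn(img.T[None, ...])
    x = torch.mean(features)
    return x.cpu()

As noted above, this is error-prone: it silently disables gradients even in contexts where they are wanted, which is why a solution based on grad_enabled or is_grad_suspended would be preferable.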

@lynshwoo2022

Hi @bathal1, I encountered some problems: mitsuba-renderer/mitsuba3#467. Could they be caused by the issue you describe here?

bathal1 added the bug label on Jan 3, 2023