Using MetaAdam results in 'RuntimeError: Trying to backward through the graph a second time' #191
-
Hi! I would like to build on top of your Meta-Gradient RL example with the MetaAdam optimizer. However, if I simply replace the MetaSGD optimizer in your code with MetaAdam, I get `RuntimeError: Trying to backward through the graph a second time`. Here is the full code to reproduce the problem:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt


def test_gamma():
    class Rollout:
        @staticmethod
        def get():
            out = torch.empty(5, 2)
            out[:, 0] = torch.randn(5)
            out[:, 1] = 0.1 * torch.ones(5)
            label = torch.arange(0, 10)
            return out.view(10, 1), F.one_hot(label, 10)

        @staticmethod
        def rollout(trajectory, gamma):
            out = [trajectory[-1]]
            for i in reversed(range(9)):
                out.append(trajectory[i] + gamma[i] * out[-1].clone().detach_())
            out.reverse()
            return torch.hstack(out).view(10, 1)

    class ValueNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 1)

        def forward(self, x):
            return self.fc(x)

    torch.manual_seed(0)
    inner_iters = 1
    outer_iters = 10000

    net = ValueNetwork()
    inner_optimizer = torchopt.MetaAdam(net, lr=5e-1, moment_requires_grad=False)
    gamma = torch.zeros(9, requires_grad=True)
    meta_optimizer = torchopt.SGD([gamma], lr=5e-1)
    net_state = torchopt.extract_state_dict(net)

    for i in range(outer_iters):
        for _ in range(inner_iters):
            trajectory, state = Rollout.get()
            backup = Rollout.rollout(trajectory, torch.sigmoid(gamma))
            pred_value = net(state.float())
            loss = F.mse_loss(pred_value, backup)
            inner_optimizer.step(loss)

        trajectory, state = Rollout.get()
        pred_value = net(state.float())
        backup = Rollout.rollout(trajectory, torch.ones_like(gamma))
        loss = F.mse_loss(pred_value, backup)

        meta_optimizer.zero_grad()
        loss.backward()
        meta_optimizer.step()

        torchopt.recover_state_dict(net, net_state)
        torchopt.stop_gradient(net)

        if i % 100 == 0:
            with torch.no_grad():
                print(f'epoch {i} | gamma: {torch.sigmoid(gamma)}')


if __name__ == '__main__':
    test_gamma()
```

Is this a bug or am I missing something? Thanks for your help!
-
Hi @dierkes-j, thanks for raising this. You need to extract your model state at the beginning of each outer loop, and also recover the model state at the end of each inner loop. Here is the suggestion:

```diff
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt


def test_gamma():
    class Rollout:
        @staticmethod
        def get():
            out = torch.empty(5, 2)
            out[:, 0] = torch.randn(5)
            out[:, 1] = 0.1 * torch.ones(5)
            label = torch.arange(0, 10)
            return out.view(10, 1), F.one_hot(label, 10)

        @staticmethod
        def rollout(trajectory, gamma):
            out = [trajectory[-1]]
            for i in reversed(range(9)):
                out.append(trajectory[i] + gamma[i] * out[-1].clone().detach_())
            out.reverse()
            return torch.hstack(out).view(10, 1)

    class ValueNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 1)

        def forward(self, x):
            return self.fc(x)

    torch.manual_seed(0)
    inner_iters = 1
    outer_iters = 10000

    net = ValueNetwork()
    inner_optimizer = torchopt.MetaAdam(net, lr=5e-1, moment_requires_grad=False)
    gamma = torch.zeros(9, requires_grad=True)
    meta_optimizer = torchopt.SGD([gamma], lr=5e-1)
-   net_state = torchopt.extract_state_dict(net)

    for i in range(outer_iters):
+       net_state = torchopt.extract_state_dict(net)
        for _ in range(inner_iters):
            trajectory, state = Rollout.get()
            backup = Rollout.rollout(trajectory, torch.sigmoid(gamma))
            pred_value = net(state.float())
            loss = F.mse_loss(pred_value, backup)
            inner_optimizer.step(loss)
+       torchopt.recover_state_dict(net, net_state)

        trajectory, state = Rollout.get()
        pred_value = net(state.float())
        backup = Rollout.rollout(trajectory, torch.ones_like(gamma))
        loss = F.mse_loss(pred_value, backup)

        meta_optimizer.zero_grad()
        loss.backward()
        meta_optimizer.step()

-       torchopt.recover_state_dict(net, net_state)
        torchopt.stop_gradient(net)

        if i % 100 == 0:
            with torch.no_grad():
                print(f'epoch {i} | gamma: {torch.sigmoid(gamma)}')


if __name__ == '__main__':
    test_gamma()
```

Here are the results:

```console
$ python3 test_gamma.py
epoch 0 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 100 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 200 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 300 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 400 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 500 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 600 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 700 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 800 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 900 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
epoch 1000 | gamma: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
...
```
-
Hi @XuehaiPan, thanks for your fast reply and help! I am a little bit confused about your reply. Doesn't moving the line `net_state = torchopt.extract_state_dict(net)` into the outer loop effectively change nothing here? From my understanding, I would simply omit the extraction and recovering of the network state altogether.
-
I see, that makes sense to me! Thanks again for your help :)
Yes, you are correct.
The problem is that you should create the inner optimizer at the beginning of each outer loop. In your code snippet, the inner optimizer is shared across multiple outer-loop optimizations. You should either extract/recover and detach the state of the inner optimizer, just as you do for the network parameters, or recreate a new inner optimizer on each outer iteration.
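For concreteness, here is a minimal, untested sketch of the two options, reusing the `net`, `gamma`, `Rollout`, `inner_iters`, `outer_iters`, and `meta_optimizer` defined in the script above. It assumes that `torchopt.extract_state_dict`, `torchopt.recover_state_dict`, and `torchopt.stop_gradient` also accept a `MetaOptimizer` (as in the TorchOpt MAML tutorial); please double-check against the TorchOpt version you are using.

```python
# Option 1: extract/recover and detach the inner-optimizer state, mirroring
# what is already done for the network parameters.
inner_optimizer = torchopt.MetaAdam(net, lr=5e-1, moment_requires_grad=False)
net_state = torchopt.extract_state_dict(net)
optim_state = torchopt.extract_state_dict(inner_optimizer)  # assumes MetaOptimizer support

for i in range(outer_iters):
    for _ in range(inner_iters):
        trajectory, state = Rollout.get()
        backup = Rollout.rollout(trajectory, torch.sigmoid(gamma))
        loss = F.mse_loss(net(state.float()), backup)
        inner_optimizer.step(loss)  # differentiable inner update

    trajectory, state = Rollout.get()
    backup = Rollout.rollout(trajectory, torch.ones_like(gamma))
    outer_loss = F.mse_loss(net(state.float()), backup)

    meta_optimizer.zero_grad()
    outer_loss.backward()
    meta_optimizer.step()

    # Reset the network and the inner-optimizer moments, and cut the autograd
    # history so the next outer iteration builds a fresh graph instead of
    # backwarding through the old (already freed) one.
    torchopt.recover_state_dict(net, net_state)
    torchopt.recover_state_dict(inner_optimizer, optim_state)
    torchopt.stop_gradient(net)
    torchopt.stop_gradient(inner_optimizer)  # assumes MetaOptimizer support

# Option 2: simply recreate the inner optimizer at the top of each outer
# iteration, so its Adam moments never reference a previous graph.
for i in range(outer_iters):
    inner_optimizer = torchopt.MetaAdam(net, lr=5e-1, moment_requires_grad=False)
    ...  # same inner/outer updates as above, without the optimizer extract/recover
```

Either way, the key point is that after calling `backward()` on the outer loss, nothing that still carries the old graph (network parameters or Adam moments) may be reused in the next outer iteration.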