
Fix PSD bug for quadprog, and adapt code to PyTorch 0.4.
David Lopez-Paz committed Oct 22, 2018
1 parent 4c2fae9 commit 34c6b8e
Showing 10 changed files with 71 additions and 37 deletions.
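For context on the second half of the commit title: PyTorch 0.4 merged `Variable` into `Tensor`, which is why the `Variable(...)` wrappers are deleted throughout the files below. A minimal, illustrative sketch of the before/after pattern (not taken from this repository):

```python
import torch

x = torch.randn(4, 3, requires_grad=True)

# Pre-0.4 style (what this commit removes):
#   from torch.autograd import Variable
#   y = (Variable(x) * 2).sum()
# 0.4 style: tensors participate in autograd directly.
y = (x * 2).sum()
y.backward()          # gradients flow without any Variable wrapper
print(x.grad.shape)   # torch.Size([4, 3])
```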
5 changes: 5 additions & 0 deletions CODE_OF_CONDUCT.md
@@ -0,0 +1,5 @@
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
32 changes: 32 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,32 @@
# Contributing to GradientEpisodicMemory
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to GradientEpisodicMemory, you agree that your contributions
will be licensed under the LICENSE file in the root directory of this source
tree.
3 changes: 3 additions & 0 deletions README.md
@@ -12,3 +12,6 @@ Source code for [the paper](https://arxiv.org/abs/1706.08840):
```

To replicate the experiments, execute `./run_experiments.sh`.

This source code is released under an Attribution-NonCommercial 4.0 International
license; find out more about it [here](LICENSE).
8 changes: 3 additions & 5 deletions main.py
@@ -15,7 +15,6 @@
import numpy as np

import torch
from torch.autograd import Variable
from metrics.metrics import confusion_matrix

# continuum iterator #########################################################
@@ -26,8 +25,8 @@ def load_datasets(args):
n_inputs = d_tr[0][1].size(1)
n_outputs = 0
for i in range(len(d_tr)):
n_outputs = max(n_outputs, d_tr[i][2].max())
n_outputs = max(n_outputs, d_te[i][2].max())
n_outputs = max(n_outputs, d_tr[i][2].max().item())
n_outputs = max(n_outputs, d_te[i][2].max().item())
return d_tr, d_te, n_inputs, n_outputs + 1, len(d_tr)
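The hunk above works around another 0.4 change: a full reduction such as `.max()` now returns a 0-dim tensor rather than a Python number, so `.item()` is needed before mixing the result with Python's built-in `max()`. A small illustrative check:

```python
import torch

y = torch.LongTensor([3, 7, 1])
print(y.max())         # tensor(7): a 0-dim tensor in PyTorch >= 0.4
print(y.max().item())  # 7: a plain Python int, safe to pass to built-in max()
```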


@@ -112,7 +111,6 @@ def eval_tasks(model, tasks, args):
yb = y[b_from:b_to]
if args.cuda:
xb = xb.cuda()
xb = Variable(xb, volatile=True)
_, pb = torch.max(model(xb, t).data.cpu(), 1, keepdim=False)
rt += (pb == yb).float().sum()
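Dropping `volatile=True` is also part of the 0.4 port: the flag is deprecated, and inference is instead wrapped in `torch.no_grad()`. A hedged sketch of an evaluation step in that style (the diff itself removes the flag without adding the context manager):

```python
import torch

def eval_batch(model, xb, t):
    # torch.no_grad() is the 0.4 replacement for volatile=True: it skips
    # building the autograd graph while evaluating.
    with torch.no_grad():
        _, pb = torch.max(model(xb, t), 1, keepdim=False)
    return pb  # predicted class index per example
```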

@@ -142,7 +140,7 @@ def life_experience(model, continuum, x_te, args):
v_y = v_y.cuda()

model.train()
model.observe(Variable(v_x), t, Variable(v_y))
model.observe(v_x, t, v_y)

result_a.append(eval_tasks(model, x_te, args))
result_t.append(current_task)
2 changes: 1 addition & 1 deletion metrics/metrics.py
@@ -26,7 +26,7 @@ def confusion_matrix(result_t, result_a, fname=None):

baseline = result_a[0]
changes = torch.LongTensor(changes + [result_a.size(0)]) - 1
result = result_a.index(torch.LongTensor(changes))
result = result_a[changes]

# acc[t] equals result[t,t]
acc = result.diag()
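The metrics fix replaces the `result_a.index(...)` call with plain tensor indexing: indexing with a `LongTensor` selects rows. A toy example with made-up values:

```python
import torch

result_a = torch.arange(12.).view(4, 3)  # hypothetical 4 evaluations x 3 tasks
changes = torch.LongTensor([0, 2, 3])    # hypothetical task-boundary rows
result = result_a[changes]               # rows 0, 2 and 3 -> shape (3, 3)
acc = result.diag()                      # accuracy on each task at its boundary
```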
13 changes: 6 additions & 7 deletions model/ewc.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import torch
from torch.autograd import Variable
from .common import MLP, ResNet18


@@ -76,12 +75,12 @@ def observe(self, x, t, y):

if self.is_cifar:
offset1, offset2 = self.compute_offsets(self.current_task)
self.bce((self.net(Variable(self.memx))[:, offset1: offset2]),
Variable(self.memy) - offset1).backward()
self.bce((self.net(self.memx)[:, offset1: offset2]),
self.memy - offset1).backward()
else:
self.bce(self(Variable(self.memx),
self.bce(self(self.memx,
self.current_task),
Variable(self.memy)).backward()
self.memy).backward()
self.fisher[self.current_task] = []
self.optpar[self.current_task] = []
for p in self.net.parameters():
@@ -113,8 +112,8 @@ def observe(self, x, t, y):
loss = self.bce(self(x, t), y)
for tt in range(t):
for i, p in enumerate(self.net.parameters()):
l = self.reg * Variable(self.fisher[tt][i])
l = l * (p - Variable(self.optpar[tt][i])).pow(2)
l = self.reg * self.fisher[tt][i]
l = l * (p - self.optpar[tt][i]).pow(2)
loss += l.sum()
loss.backward()
self.opt.step()
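With the wrappers gone, the stored Fisher diagonals and parameter snapshots are plain tensors and can enter the regularizer directly. A self-contained sketch of the EWC penalty as it reads after this change (the function name and argument layout are illustrative):

```python
import torch

def ewc_penalty(net, fisher, optpar, reg):
    # fisher[tt][i]: diagonal Fisher estimate for parameter i of past task tt
    # optpar[tt][i]: value of parameter i saved when task tt finished
    loss = 0.0
    for tt in range(len(fisher)):
        for i, p in enumerate(net.parameters()):
            loss = loss + (reg * fisher[tt][i] * (p - optpar[tt][i]).pow(2)).sum()
    return loss
```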
9 changes: 4 additions & 5 deletions model/gem.py
@@ -7,7 +7,6 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import numpy as np
import quadprog
@@ -68,7 +67,7 @@ def overwrite_grad(pp, newgrad, grad_dims):
cnt += 1


def project2cone2(gradient, memories, margin=0.5):
def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
"""
Solves the GEM dual QP described in the paper given a proposed
gradient "gradient", and a memory of task gradients "memories".
@@ -82,7 +81,7 @@ def project2cone2(gradient, memories, margin=0.5):
gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()
t = memories_np.shape[0]
P = np.dot(memories_np, memories_np.transpose())
P = 0.5 * (P + P.transpose())
P = 0.5 * (P + P.transpose()) + np.eye(t) * eps
q = np.dot(memories_np, gradient_np) * -1
G = np.eye(t)
h = np.zeros(t) + margin
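This hunk is the PSD fix named in the commit title. `quadprog.solve_qp` requires a strictly positive definite quadratic term, but the Gram matrix of the memory gradients is only positive semi-definite (singular whenever the stored gradients are linearly dependent), so the solver could fail. Symmetrizing `P` and adding `eps * I` makes it positive definite. A sketch of the projection with the fix applied; the final solve-and-recombine lines fall outside the visible hunk and are reconstructed here under the assumption that they follow the paper's dual formulation:

```python
import numpy as np
import quadprog

def project_to_cone(gradient_np, memories_np, margin=0.5, eps=1e-3):
    # gradient_np: proposed gradient, shape (d,); memories_np: past-task
    # gradients, shape (t, d); both float64, as quadprog expects.
    t = memories_np.shape[0]
    P = np.dot(memories_np, memories_np.transpose())
    P = 0.5 * (P + P.transpose()) + np.eye(t) * eps  # symmetrize and make PD
    q = -np.dot(memories_np, gradient_np)
    G = np.eye(t)
    h = np.zeros(t) + margin
    v = quadprog.solve_qp(P, q, G, h)[0]             # dual variables
    return np.dot(v, memories_np) + gradient_np      # projected gradient
```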
@@ -183,9 +182,9 @@ def observe(self, x, t, y):
self.is_cifar)
ptloss = self.ce(
self.forward(
Variable(self.memory_data[past_task]),
self.memory_data[past_task],
past_task)[:, offset1: offset2],
Variable(self.memory_labs[past_task] - offset1))
self.memory_labs[past_task] - offset1)
ptloss.backward()
store_grad(self.parameters, self.grads, self.grad_dims,
past_task)
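For reference, the `offset1`/`offset2` pair used when replaying a past task follows the GEM convention that, on split CIFAR, each task owns a contiguous slice of the single output layer, and labels are shifted down by `offset1` before the cross-entropy call. An illustrative helper (simplified; the call in the hunk above also passes `self.is_cifar`, which this sketch omits):

```python
def compute_offsets(task, nc_per_task):
    # Task `task` owns logits [task * nc_per_task, (task + 1) * nc_per_task).
    offset1 = task * nc_per_task
    offset2 = (task + 1) * nc_per_task
    return offset1, offset2

print(compute_offsets(2, 5))  # (10, 15): task 2 owns logits 10..14
```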
31 changes: 15 additions & 16 deletions model/icarl.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import torch
from torch.autograd import Variable

import numpy as np
import random
@@ -45,8 +44,8 @@ def __init__(self,
# setup losses
self.bce = torch.nn.CrossEntropyLoss()
self.kl = torch.nn.KLDivLoss() # for distillation
self.lsm = torch.nn.LogSoftmax()
self.sm = torch.nn.Softmax()
self.lsm = torch.nn.LogSoftmax(dim=1)
self.sm = torch.nn.Softmax(dim=1)

# memory
self.memx = None # stores raw inputs, PxD
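Passing `dim=1` to `LogSoftmax`/`Softmax` above is the clean 0.4 behavior: without an explicit dim the module guesses the dimension and emits a deprecation warning. With `(batch, classes)` inputs, `dim=1` normalizes over classes:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 10)          # hypothetical (batch, classes) scores
sm = nn.Softmax(dim=1)
lsm = nn.LogSoftmax(dim=1)
print(sm(logits).sum(dim=1))         # each row sums to 1
print(lsm(logits).exp().sum(dim=1))  # same check via log-probabilities
```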
@@ -76,28 +75,28 @@ def forward(self, x, t):
1.0 / self.nc_per_task)
if self.gpu:
out = out.cuda()
return Variable(out)
return out
means = torch.ones(self.nc_per_task, nd) * float('inf')
if self.gpu:
means = means.cuda()
offset1, offset2 = self.compute_offsets(t)
for cc in range(offset1, offset2):
means[cc -
offset1] = self.net(Variable(self.mem_class_x[cc])).data.mean(0)
offset1] = self.net(self.mem_class_x[cc]).data.mean(0)
classpred = torch.LongTensor(ns)
preds = self.net(x).data.clone()
for ss in range(ns):
dist = (means - preds[ss].expand(self.nc_per_task, nd)).norm(2, 1)
_, ii = dist.min(0)
ii = ii.squeeze()
classpred[ss] = ii[0] + offset1
classpred[ss] = ii.item() + offset1

out = torch.zeros(ns, self.n_classes)
if self.gpu:
out = out.cuda()
for ss in range(ns):
out[ss, classpred[ss]] = 1
return Variable(out) # return 1-of-C code, ns x nc
return out # return 1-of-C code, ns x nc

def forward_training(self, x, t):
output = self.net(x)
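The forward pass above classifies by nearest exemplar mean; the switch to `ii.item()` is needed because `dist.min(0)` now returns 0-dim tensors. A toy version of the nearest-mean step with made-up shapes:

```python
import torch

means = torch.randn(5, 16)   # hypothetical per-class feature means
feats = torch.randn(3, 16)   # hypothetical features for 3 test points
for ss in range(feats.size(0)):
    dist = (means - feats[ss].expand(5, 16)).norm(2, 1)  # distance to each mean
    _, ii = dist.min(0)
    print(ss, ii.item())     # index of the closest class mean
```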
@@ -137,15 +136,15 @@ def observe(self, x, t, y):
inp_dist = inp_dist.cuda()
target_dist = target_dist.cuda()
for cc in range(self.nc_per_task):
indx = random.randint(0, self.num_exemplars - 1)
indx = random.randint(0, len(self.mem_class_x[cc + offset1]) - 1)
inp_dist[cc] = self.mem_class_x[cc + offset1][indx].clone()
target_dist[cc] = self.mem_class_y[cc +
offset1][indx].clone()
# Add distillation loss
loss += self.reg * self.kl(
self.lsm(self.net(Variable(inp_dist))
self.lsm(self.net(inp_dist)
[:, offset1: offset2]),
self.sm(Variable(target_dist[:, offset1: offset2]))) * self.nc_per_task
self.sm(target_dist[:, offset1: offset2])) * self.nc_per_task
# bprop and update
loss.backward()
self.opt.step()
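The distillation term combines the modules set up in `__init__`: `KLDivLoss` expects log-probabilities for the current network and probabilities for the stored teacher outputs. A minimal sketch with placeholder logits:

```python
import torch
import torch.nn as nn

kl = nn.KLDivLoss()
lsm, sm = nn.LogSoftmax(dim=1), nn.Softmax(dim=1)

student_logits = torch.randn(5, 10)  # hypothetical current-network outputs
teacher_logits = torch.randn(5, 10)  # hypothetical stored exemplar outputs
loss = kl(lsm(student_logits), sm(teacher_logits))
```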
@@ -166,12 +165,12 @@ def observe(self, x, t, y):
(num_classes + len(self.mem_class_x.keys())))
offset1, offset2 = self.compute_offsets(t)
for ll in range(num_classes):
lab = all_labs[ll]
lab = all_labs[ll].cuda()
indxs = (self.memy == lab).nonzero().squeeze()
cdata = self.memx.index_select(0, indxs)

# Construct exemplar set for last task
mean_feature = self.net(Variable(cdata))[
mean_feature = self.net(cdata)[
:, offset1: offset2].data.clone().mean(0)
nd = self.nc_per_task
exemplars = torch.zeros(self.num_exemplars, x.size(1))
@@ -180,14 +179,14 @@ def observe(self, x, t, y):
ntr = cdata.size(0)
# used to keep track of which examples we have already used
taken = torch.zeros(ntr)
model_output = self.net(Variable(cdata))[
model_output = self.net(cdata)[
:, offset1: offset2].data.clone()
for ee in range(self.num_exemplars):
prev = torch.zeros(1, nd)
if self.gpu:
prev = prev.cuda()
if ee > 0:
prev = self.net(Variable(exemplars[:ee]))[
prev = self.net(exemplars[:ee])[
:, offset1: offset2].data.clone().sum(0)
cost = (mean_feature.expand(ntr, nd) - (model_output
+ prev.expand(ntr, nd)) / (ee + 1)).norm(2, 1).squeeze()
@@ -203,11 +202,11 @@ def observe(self, x, t, y):
self.num_exemplars = indx.size(0)
break
# update memory with exemplars
self.mem_class_x[lab] = exemplars.clone()
self.mem_class_x[lab.item()] = exemplars.clone()

# recompute outputs for distillation purposes
for cc in self.mem_class_x.keys():
self.mem_class_y[cc] = self.net(
Variable(self.mem_class_x[cc])).data.clone()
self.mem_class_x[cc]).data.clone()
self.memx = None
self.memy = None
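Keying the exemplar dictionary by `lab.item()` rather than by the label tensor matters because tensors hash by object identity, so a later lookup with an equal-valued but distinct tensor, or with a plain int, would miss. A tiny illustration:

```python
import torch

mem = {}
lab = torch.tensor(3)              # a 0-dim label tensor, as in the loop above
mem[lab.item()] = "exemplars for class 3"
print(mem[3])                      # found; mem[torch.tensor(3)] would raise KeyError
```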
3 changes: 1 addition & 2 deletions model/independent.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import torch
from torch.autograd import Variable
from .common import MLP, ResNet18


@@ -56,7 +55,7 @@ def forward(self, x, t):
bigoutput.fill_(-10e10)
bigoutput[:, int(t * self.nc_per_task): int((t + 1) * self.nc_per_task)].copy_(
output.data)
return Variable(bigoutput)
return bigoutput
else:
return output

2 changes: 1 addition & 1 deletion model/multimodal.py
@@ -49,7 +49,7 @@ def __init__(self,
reset_bias(self.i_layer[-1])

self.relu = nn.ReLU()
self.soft = nn.LogSoftmax()
self.soft = nn.LogSoftmax(dim=1)
self.loss = nn.NLLLoss()
self.optimizer = torch.optim.SGD(self.parameters(), args.lr)

