import numpy as np
import torch as T
import torchvision
from torchvision import datasets, transforms
from tqdm import tqdm

from utils import ProgrammingNetwork


def get_mnist(batch_size):
    """
    This function returns the train and test loaders of the MNIST
    dataset for a given batch size.

    :param batch_size: size of the batch for the data loaders
    :type batch_size: int
    :return: train and test loader
    :rtype: tuple[torch.utils.data.DataLoader]
    """
    train_loader = T.utils.data.DataLoader(
        datasets.MNIST(
            './data', train=True, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=batch_size, shuffle=True
    )
    test_loader = T.utils.data.DataLoader(
        datasets.MNIST(
            './data', train=False, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=batch_size, shuffle=True
    )
    return train_loader, test_loader
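

# A minimal usage sketch (not executed by this script): pulling one batch from the
# loaders above to check shapes. The batch size of 16 is arbitrary, chosen only
# for illustration.
#
#   train_loader, test_loader = get_mnist(16)
#   x, y = next(iter(train_loader))
#   print(x.shape, y.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16])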


def make_programming_network(pretrained_model, input_size=224, patch_size=28, device="cpu"):
    """Build a ProgrammingNetwork around a pretrained classifier
    (defaults: 224x224 program, 28x28 MNIST patch)."""
    model = ProgrammingNetwork(pretrained_model, input_size, patch_size, blur_sigma=1.5, device=device)
    return model


def run_test_accuracy(model, test_loader, device='cpu'):
    """Return the mean top-1 accuracy of `model` over `test_loader`."""
    test_accuracy = []
    with T.no_grad():  # evaluation only, no gradients needed
        for x, y in tqdm(test_loader):
            y_hat = model(x.to(device))
            # 1.0 where the predicted class matches the label, 0.0 otherwise
            test_accuracy.extend((y_hat.argmax(1).to("cpu") == y).float().numpy())
    return np.array(test_accuracy).mean()
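

# A minimal evaluation sketch, assuming `model` was built with
# make_programming_network and a trained program has already been loaded into
# model.p (the printed value is illustrative, not a reported result):
#
#   _, test_loader = get_mnist(16)
#   print("test accuracy:", run_test_accuracy(model, test_loader, device="cpu"))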


def prune_program(model, path, band_width, band_value=0, batch_size=16, location="cpu", device='cpu'):
    """Zero out a border band of the learned program and measure MNIST test accuracy."""
    # Program layout: (channels, height, width)
    p = T.load(path, map_location=location)
    p[:, :band_width, :] = p[:, -band_width:, :] = band_value
    p[:, :, :band_width] = p[:, :, -band_width:] = band_value
    model.p = T.as_tensor(p).clone().detach()
    model = to_device(model, device=device)
    model.p.requires_grad = False  # evaluation only, the program is not updated
    _, test_loader = get_mnist(batch_size)
    return run_test_accuracy(model, test_loader, device=device)
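

# A toy illustration of the border pruning performed above, on a 1x6x6 program
# instead of the full 224x224 one: zeroing a band of width 2 leaves only the
# central 2x2 block untouched (same slicing pattern as in prune_program).
#
#   toy = T.ones(1, 6, 6)
#   toy[:, :2, :] = toy[:, -2:, :] = 0
#   toy[:, :, :2] = toy[:, :, -2:] = 0
#   print(toy[0])  # ones remain only at toy[0, 2:4, 2:4]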


def to_device(programmingNetwork, device="cpu"):
    """Move the program, mask, and wrapped model of a ProgrammingNetwork to `device`."""
    programmingNetwork.device = device
    # Re-create the program as a leaf tensor on the target device
    programmingNetwork.p = programmingNetwork.p.detach().to(device).requires_grad_(True)
    programmingNetwork.mask = programmingNetwork.mask.to(device)
    programmingNetwork.one = programmingNetwork.one.to(device)
    programmingNetwork.model = programmingNetwork.model.to(device)
    return programmingNetwork


DEVICE = "cuda:0"
PATH = "/tmp/a.pth"  # path to a saved program tensor (loaded by prune_program)

pretrained_model = torchvision.models.squeezenet1_0(pretrained=True).eval()
model = make_programming_network(pretrained_model, device=DEVICE)

# Band widths to prune: 1..10, 15..50 in steps of 5, then 100 and 112
# (a width of 112 zeroes the entire 224x224 program).
band_widths = list(range(1, 11)) + list(range(15, 51, 5)) + [100, 112]
band_value = 0

test_pruning_accuracy = {}
for band_width in band_widths:
    test_pruning_accuracy[band_width] = prune_program(
        model, PATH,
        band_width,
        band_value=band_value,
        device=DEVICE
    )
    print("\n - accuracy for band width {}: {}".format(
        band_width,
        test_pruning_accuracy[band_width]
    ))

np.save("./models/MNIST_Squeeze1_0_test_pruning_accuracy", test_pruning_accuracy)
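
# A minimal sketch for reading the results back later: np.save stores the Python
# dict as a 0-d object array, so the load needs allow_pickle=True and .item()
# (np.save appends the .npy extension to the path used above).
#
#   results = np.load("./models/MNIST_Squeeze1_0_test_pruning_accuracy.npy",
#                     allow_pickle=True).item()
#   print(results)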