-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_train.py
70 lines (54 loc) · 3.09 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from train import train_valid_model, Net, VGG_net
from plot import plot_results
import torch.optim as optim
import argparse
import pickle
from data import generate_cifar_loaders
# Command-line configuration for the architecture sweep below.
# NOTE(review): --data_dir and --batch_size are parsed but never read anywhere
# in this script (generate_cifar_loaders is called without them) — wire them
# through once data.generate_cifar_loaders' signature is confirmed.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='./../data', help='directory for data')
# Restrict to the two supported experiments: any other value used to skip all
# training silently and still overwrite the results file with an empty list.
parser.add_argument('--model', default='cnn', choices=['cnn', 'vgg'],
                    help='experiment with either simple cnn or vgg')
parser.add_argument('--max_layer', type=int, default=2, help='possible number of layers in the CNN')
parser.add_argument('--batch_size', type=int, default=4, help='batch size')
parser.add_argument('--epochs', type=int, default=50, help='num epoch')
parser.add_argument('--early_stop', type=int, default=10,
                    help='nb of epochs with no improvements before stopping')
parser.add_argument('--training_size', type=int, default=8000, help='nb of training examples')
opt = parser.parse_args()
def _count_params(model):
    """Return the total number of scalar elements across all parameters of *model*."""
    return sum(p.numel() for p in model.parameters())


model_choice = opt.model
# One dict per trained model; pickled to disk at the end of the script.
results = []

# NOTE(review): opt.data_dir and opt.batch_size are not forwarded here, and the
# meaning of the second positional argument (0) is not visible from this file —
# confirm against data.generate_cifar_loaders.
train_loader, valid_loader = generate_cifar_loaders(opt.training_size, 0)

if model_choice == 'cnn':
    # Sweep every [i, j] pair with i, j drawn from range(max_layer).
    # NOTE(review): range() excludes max_layer itself, so the default of 2 only
    # tries 0 and 1 — confirm this matches the meaning of Net's `layers` argument.
    max_layer = opt.max_layer
    layers = [[i, j] for i in range(max_layer) for j in range(max_layer)]
    print(layers)
    for layer in layers:
        net = Net(layers=layer)
        print(net)
        nb_params = _count_params(net)
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        # Pass early_stop exactly as the VGG branch does; previously --early_stop
        # was silently ignored for CNN runs.
        acc = train_valid_model(net, opt.epochs, optimizer, train_loader,
                                valid_loader, opt.early_stop, verbose=True)
        results.append({'layers': layer, 'num_params': nb_params, 'accuracy': acc})
elif model_choice == 'vgg':
    # Each architecture is a VGG-style config list: ints are conv output channels,
    # 'M' marks a max-pooling step. Input images are 3x32x32.
    # Assuming each 'M' halves height and width (TODO confirm in train.VGG_net):
    #   - 5 poolings leave a 1x1 map, so nb_features = last channel count;
    #   - 4 poolings leave a 2x2 map, so nb_features = 4 * last channel count.
    # nb_features is the flattened size fed to the classifier.
    architectures = {
        'ex_1': [16, 'M', 32, 'M', 64, 'M', 128, 'M', 256, 'M'],
        'ex_2': [32, 'M', 64, 'M', 128, 'M', 256, 'M', 512, 'M'],
        'ex_3': [32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M'],
        'ex_4': [16, 'M', 32, 'M', 64, 'M', 128, 'M'],
        'ex_5': [16, 16, 'M', 32, 32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M'],
        'ex_6': [16, 16, 16, 'M', 32, 32, 32, 'M', 64, 64, 64, 'M',
                 128, 128, 128, 'M', 256, 256, 256, 'M'],
        'ex_7': [16, 16, 16, 16, 'M', 32, 32, 32, 32, 'M', 64, 64, 64, 64, 'M',
                 128, 128, 128, 128, 'M', 256, 256, 256, 256, 'M'],
    }
    nb_features = {'ex_1': 256, 'ex_2': 512, 'ex_3': 512, 'ex_4': 512,
                   'ex_5': 256, 'ex_6': 256, 'ex_7': 256}
    for ex, arch in architectures.items():
        net = VGG_net(arch, nb_features[ex])
        print(arch)
        nb_params = _count_params(net)
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        acc = train_valid_model(net, opt.epochs, optimizer, train_loader,
                                valid_loader, opt.early_stop, verbose=True)
        results.append({'architecture': arch, 'num_params': nb_params, 'accuracy': acc})

# Despite the ".np" extension, this file is a pickle of the `results` list.
with open("test.np", "wb") as fp:
    pickle.dump(results, fp)
#plot_results("test.np", 'test.png')