-
Notifications
You must be signed in to change notification settings - Fork 8
/
main_plus.py
101 lines (84 loc) · 3.86 KB
/
main_plus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import os
import random
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from Model_plus import GGNN_plus
from data.dataloader import ABoxDataloader
from data.dataset_plus import ABoxDataset_plus
from utils.test_plus import test
from utils.train_plus import train
# ---------------------------------------------------------------------------
# Command-line options, RNG seeding and checkpoint naming.
# This runs at import time and leaves `opt` and `fileName` as module globals.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--workers', type=int, help='number of data loading workers', default=0)
parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
parser.add_argument('--annotation_dim', type=int, default=20, help='annotation dimension for nodes')
parser.add_argument('--state_dim', type=int, default=20, help='GGNN hidden state size')
parser.add_argument('--n_steps', type=int, default=5, help='propogation steps number of GGNN')
parser.add_argument('--niter', type=int, default=500, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.00001, help='learning rate')
parser.add_argument('--dropout_rate', type=float, default=0.0, help='probability of dropout')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
# A store_true flag whose default is already True can never be turned off,
# so each such flag gets a --no_* counterpart writing False to the same dest.
parser.add_argument('--use_bias', action='store_true', help='enables bias for edges', default=True)
parser.add_argument('--no_use_bias', dest='use_bias', action='store_false',
                    help='disables bias for edges')
parser.add_argument('--verbal', action='store_true', help='print training info or not', default=True)
parser.add_argument('--no_verbal', dest='verbal', action='store_false',
                    help='do not print training info')
parser.add_argument('--manualSeed', type=int, help='manual seed', default=983)
# Dataset path; other datasets that have been used with this script include
# data/www.db2.json and data/www1.json.
parser.add_argument('--dataroot', type=str, default='data/yago.base#.json',
                    help='path to the dataset JSON file')
# Early-stopping patience (epochs in a row below the best test accuracy).
parser.add_argument('--patience', type=int, default=15,
                    help='stop after this many epochs without matching the best accuracy')
opt = parser.parse_args()
print(opt)

# Seed every RNG the script has access to so runs are reproducible.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)  # numpy is imported, so seed its global RNG too
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

# Dataset file name without its directory; main() uses it as the checkpoint
# file prefix. basename() matches the old `dataroot[5:]` for 'data/<file>'
# but does not assume the 'data/' prefix.
fileName = os.path.basename(opt.dataroot)
def main(opt):
    """Train a GGNN_plus model and checkpoint it whenever test accuracy improves.

    Builds train/test datasets and dataloaders from ``opt.dataroot``. It then
    trains for up to ``opt.niter`` epochs with Adam and ``nn.BCELoss`` and
    evaluates on the test split after every epoch. Whenever the test accuracy
    strictly exceeds the best seen so far, the whole model is saved with
    ``torch.save`` to ``./<fileName><n_steps>_model.pth``.

    Training stops early once the accuracy has stayed below the best value for
    ``opt.patience`` consecutive epochs. If ``opt`` has no ``patience``
    attribute, 15 is used.

    Args:
        opt: parsed option namespace. Read here: ``dataroot``, ``batchSize``,
            ``workers``, ``lr``, ``niter``, ``cuda``, ``n_steps`` and
            (optionally) ``patience``. ``opt`` is also passed to GGNN_plus,
            train and test, which may read more fields. This function writes
            ``opt.n_edge_types`` and ``opt.n_node`` before building the net.

    Raises:
        ValueError: if the test split is empty, because accuracy is undefined.
    """
    # NOTE(review): the second ABoxDataset_plus argument looks like an
    # is-training flag (True for train, False for test); confirm in dataset_plus.
    train_dataset = ABoxDataset_plus(opt.dataroot, True)
    train_dataloader = ABoxDataloader(train_dataset, batch_size=opt.batchSize,
                                      shuffle=True, num_workers=opt.workers)
    test_dataset = ABoxDataset_plus(opt.dataroot, False)
    test_dataloader = ABoxDataloader(test_dataset, batch_size=opt.batchSize,
                                     shuffle=False, num_workers=opt.workers)

    # Fail before spending a training epoch if accuracy cannot be computed.
    n_test = len(test_dataset)
    if n_test == 0:
        raise ValueError('test dataset at %r is empty; cannot compute accuracy'
                         % (opt.dataroot,))

    opt.n_edge_types = train_dataset.n_edge_types
    opt.n_node = train_dataset.n_node
    net = GGNN_plus(train_dataset.n_node, train_dataset.edge_id_dic,
                    train_dataset.type_id_dic, opt)
    net.double()  # cast all parameters to float64

    criterion = nn.BCELoss()
    if opt.cuda:
        net.cuda()
        criterion.cuda()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    # Both train() and test() receive the *training* split's id dictionaries.
    edge_id_dic = train_dataset.edge_id_dic
    type_id_dic = train_dataset.type_id_dic
    patience = getattr(opt, 'patience', 15)
    # fileName is the module-level dataset-name prefix computed at import time.
    checkpoint_path = './' + fileName + str(opt.n_steps) + '_model.pth'

    best_acc = 0.0  # best test accuracy achieved so far
    num_of_dec = 0  # consecutive epochs whose accuracy was below best_acc
    for epoch in range(opt.niter):
        if num_of_dec >= patience:
            print("Early stop! The accuracy has been dropped for {} iterations!"
                  .format(patience))
            break
        train(epoch, train_dataloader, train_dataset, net, criterion, optimizer,
              edge_id_dic, type_id_dic, opt)
        correct = test(test_dataloader, test_dataset, net, criterion,
                       edge_id_dic, type_id_dic, opt)
        acc = float(correct) / n_test

        if acc > best_acc:
            # New best: checkpoint the full model and reset the patience counter.
            best_acc = acc
            print("Best accuracy by far: ", best_acc)
            torch.save(net, checkpoint_path)
            num_of_dec = 0
        elif acc == best_acc:
            # Matching the best also counts as "not worse".
            num_of_dec = 0
        else:
            num_of_dec += 1
    print("The best accuracy achieved by far: ", best_acc)
if __name__ == '__main__':
    # Only start training when run as a script; importing this module just
    # parses options and defines main().
    main(opt)