train_bert_walk.py
from functools import partial
import torch
import torch.nn as nn
import time
from modeling.models import TransformerModel
from modeling.data import bert_walk_collate, BertWalkDataset
from modeling.tokenizer import bert_walk_tokenizer
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.commons import *
import random
import argparse
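
# Fix the random seeds for reproducibility and pick the training device.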
random.seed(1984)
torch.manual_seed(1984)
device = "cuda" if torch.cuda.is_available() else "cpu"


def parse_args():
    parser = argparse.ArgumentParser(description="Train the BERTwalk model on the MLM task.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
    parser.add_argument("--emsize", type=int, default=128, help="Embedding dimension.")
    parser.add_argument("--nhid", type=int, default=200, help="Number of hidden units.")
    parser.add_argument("--nlayers", type=int, default=4, help="Number of transformer layers.")
    parser.add_argument("--nhead", type=int, default=4, help="Number of attention heads in the transformer.")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate.")
    parser.add_argument("--learning_rate", type=float, default=0.0005, help="Learning rate.")
    parser.add_argument("--epochs", type=int, default=500, help="Number of training epochs.")
    parser.add_argument("--K", type=int, default=1, help="Number of propagation iterations.")
    parser.add_argument("--alpha", type=float, default=0.1, help="Restart factor for RWR.")
    parser.add_argument("--mask_rate", type=float, default=0.2, help="Masking rate for MLM.")
    parser.add_argument("--p", type=float, default=1, help="Return hyperparameter. Default is 1.")
    parser.add_argument("--q", type=float, default=1, help="In-out hyperparameter. Default is 1.")
    parser.add_argument("--walk_length", type=int, default=10, help="Length of each random walk.")
    parser.add_argument("--num_walks", type=int, default=10, help="Number of walks from each node.")
    parser.add_argument("--organism", type=str, default="human", help="Type of organism.")
    parser.add_argument("--input_graphs", nargs="*", help="Paths to the input network files.")
    return parser.parse_args()
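

# Train the model with a BERT-style masked-token objective. Each epoch first refreshes
# the propagated node embeddings via model.all_prop_emb(), then masks roughly
# `mask_rate` of the tokens in every walk and trains the model to recover them.
# The state dict of the lowest-loss epoch is returned. Note that `writer` is the
# module-level SummaryWriter created under __main__.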
def train_bert_walk(model, dataloader, model_params, tokenizer):
    lr = model_params["learning_rate"]
    mask_token_id = tokenizer.get_vocab()["[MASK]"]
    model.train()
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = float("inf")
    best_model = None
    for epoch in range(model_params["epochs"]):
        total_loss = 0
        epoch_start_time = time.time()
        model.all_prop_emb()
        for b, batch in enumerate(dataloader):
            optim.zero_grad()
            input = batch["input"].clone()
            labels = batch["input"].clone()
            src_mask = batch["src_mask"]
            src_key_padding_mask = batch["src_key_padding_mask"]
            rand_mask = ~batch["input"].bool()
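            # Choose which positions get masked: start from positions whose token id is 0
            # (presumably padding) and add ~mask_rate random positions per walk; position 0
            # of each walk is never sampled (randint starts at 1).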
            for i, row in enumerate(
                torch.randint(
                    1,
                    batch["input"].shape[1],
                    (batch["input"].shape[0], int(batch["input"].shape[1] * model_params["mask_rate"])),
                )
            ):
                rand_mask[i, row] = True
            mask_idx = rand_mask.flatten().nonzero().view(-1)
            input = input.flatten()
            input[mask_idx] = mask_token_id
            input = input.view(batch["input"].size())
            labels[input != mask_token_id] = -100
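            # Forward pass on the masked walks; cross-entropy is computed only at masked
            # positions, since every other label was set to -100 (the ignore_index above).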
            out = model(input.to(device), src_mask.to(device), src_key_padding_mask.to(device))
            loss = criterion(out.view(-1, model_params["ntokens"]), labels.view(-1).to(device))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            total_loss += loss.item()
        elapsed = time.time() - epoch_start_time
        epoch_loss = total_loss / len(dataloader)
        loss_dict = {"total": epoch_loss}
        print("-" * 89)
        print(f"| epoch {epoch + 1:3d} | time: {elapsed:5.2f}s | loss {epoch_loss:5.2f} | lr {lr:02.5f}")
        print("-" * 89)
        writer.add_scalars("Loss", loss_dict, epoch + 1)
        if total_loss < best_loss:
            best_loss = total_loss
            # Snapshot the best state dict so later updates don't overwrite it.
            best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
    return best_model

# %%

if __name__ == "__main__":
    writer = SummaryWriter(flush_secs=10)
    args = parse_args()
    model_params = {
        "batch_size": args.batch_size,
        "emsize": args.emsize,
        "nhid": args.nhid,
        "nlayers": args.nlayers,
        "nhead": args.nhead,
        "dropout": args.dropout,
        "learning_rate": args.learning_rate,
        "epochs": args.epochs,
        "K": args.K,
        "alpha": args.alpha,
        "mask_rate": args.mask_rate,
        "q": args.q,
        "p": args.p,
        "walk_length": args.walk_length,
        "num_walks": args.num_walks,
        "weighted": 1,
        "directed": 0,
        "organism": args.organism,
    }
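    # Walk generation is configured as weighted and undirected (fixed here rather than
    # exposed as CLI flags).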
    # input networks
    graphs = args.input_graphs
    # read the networks and the walk corpus
    data, pyg_graphs = read_data(graphs, model_params)
    tokenizer = bert_walk_tokenizer(data, model_params)  # load the tokenizer
    # tokenize each network's nodes and move its tensors to the device
    for net in pyg_graphs:
        net["node_tokens"] = torch.tensor(tokenizer.encode(net["node_sequence"][0]).ids).to(device)
        net.edge_index = net.edge_index.to(device)
        net.edge_weight = net.edge_weight.to(device)
    model_params["ntokens"] = tokenizer.get_vocab_size()
    writer.add_hparams(model_params, {"hparam/loss": 1})  # log hyperparameters to TensorBoard
    # Building model
    model = TransformerModel(model_params, pyg_graphs).to(device)
    dataset = BertWalkDataset(data, tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=model_params["batch_size"],
        collate_fn=partial(bert_walk_collate, tokenizer=tokenizer),
        shuffle=True,
    )
    # Training model
    best_model = train_bert_walk(model, dataloader, model_params, tokenizer)
    print("Finished training")
    # saving model artifacts (assumes the artifacts/ directory exists)
    print("Saving best model!")
    torch.save(
        {
            "model_params": model_params,
            "model_state_dict": best_model,
            "networks": pyg_graphs,
            "tokenizer": tokenizer,
        },
        f"artifacts/{writer.log_dir.split('/')[1]}_model.pt",
    )
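
# Example invocation (the graph file paths below are illustrative; pass your own edge lists):
#   python train_bert_walk.py --input_graphs data/net1.edgelist data/net2.edgelist \
#       --organism human --epochs 100 --walk_length 10 --num_walks 10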