-
Notifications
You must be signed in to change notification settings - Fork 28
/
train.py
264 lines (237 loc) · 11.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import argparse
from tqdm import tqdm
from loguru import logger
import numpy as np
from scipy.stats import spearmanr
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from dataset import TrainDataset, TestDataset
from model import SimcseModel, simcse_unsup_loss, simcse_sup_loss
from transformers import BertModel, BertConfig, BertTokenizer
import os
from os.path import join
from torch.utils.tensorboard import SummaryWriter
import random
import pickle
import pandas as pd
import time
def seed_everything(seed=42):
    """
    Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (the CPU generator
    and all CUDA devices), exports ``PYTHONHASHSEED`` and configures cuDNN
    for deterministic behaviour.

    :param seed: integer seed applied to every generator.
    :return: None
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Some cuDNN methods can be random even after fixing the seed unless
    # cuDNN is told to be deterministic.
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm and may pick a
    # different one from run to run, so keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False
def train(model, train_loader, dev_loader, optimizer, args, writer=None):
    """
    Run the training loop, evaluating periodically and keeping the best checkpoint.

    Every ``args.eval_step`` optimisation steps the model is scored on
    ``dev_loader`` with :func:`evaluate`; whenever the score improves on the
    best one seen so far, the model's ``state_dict`` is written to
    ``<args.output_path>/simcse.pt``.

    :param model: model called as ``model(input_ids, attention_mask, token_type_ids)``.
    :param train_loader: iterable of batches, each a mapping holding
        ``input_ids``, ``attention_mask`` and ``token_type_ids`` tensors.
    :param dev_loader: loader handed to :func:`evaluate`.
    :param optimizer: optimizer stepped once per batch.
    :param args: namespace providing ``device``, ``epochs``, ``train_mode``,
        ``eval_step`` and ``output_path``.
    :param writer: optional object with an ``add_scalar(tag, value, step)``
        method. When omitted, a module-level ``writer`` is used if one
        exists; if none is found, scalar logging is skipped.
    :return: None
    """
    logger.info("start training")
    model.train()
    device = args.device
    if writer is None:
        # The script entry point creates a module-level ``writer``. Look it
        # up defensively so this function also works when imported.
        writer = globals().get('writer')
    # Start below any possible score so the first finite dev result always
    # produces a checkpoint, even when the correlation is <= 0.
    best = float('-inf')
    steps_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        for batch_idx, data in enumerate(tqdm(train_loader)):
            # 1-based global step counted across epochs.
            step = epoch * steps_per_epoch + batch_idx + 1
            # [batch, n, seq_len] -> [batch * n, seq_len]
            seq_len = data['input_ids'].shape[-1]
            input_ids = data['input_ids'].view(-1, seq_len).to(device)
            attention_mask = data['attention_mask'].view(-1, seq_len).to(device)
            token_type_ids = data['token_type_ids'].view(-1, seq_len).to(device)

            out = model(input_ids, attention_mask, token_type_ids)
            if args.train_mode == 'unsupervise':
                loss = simcse_unsup_loss(out, device)
            else:
                loss = simcse_sup_loss(out, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % args.eval_step == 0:
                corrcoef = evaluate(model, dev_loader, device)
                # Log a plain Python float rather than the tensor repr.
                loss_value = loss.item()
                logger.info('loss:{}, corrcoef: {} in step {} epoch {}'.format(loss_value, corrcoef, step, epoch))
                if writer is not None:
                    writer.add_scalar('loss', loss_value, step)
                    writer.add_scalar('corrcoef', corrcoef, step)
                # evaluate() switched the model to eval mode; switch it back.
                model.train()
                if best < corrcoef:
                    best = corrcoef
                    torch.save(model.state_dict(), join(args.output_path, 'simcse.pt'))
                    logger.info('higher corrcoef: {} in step {} epoch {}, save model'.format(best, step, epoch))
def evaluate(model, dataloader, device):
    """
    Score the model on a sentence-pair dataset.

    For every batch, both sides of the pair are passed through the model,
    their cosine similarity is computed, and at the end the Spearman rank
    correlation between the labels and the similarities is returned.

    :param model: model called as ``model(input_ids, attention_mask, token_type_ids)``;
        it is left in eval mode when this function returns.
    :param dataloader: iterable yielding ``(source, target, label)`` triples.
    :param device: device the inputs are moved to.
    :return: the ``correlation`` attribute of ``scipy.stats.spearmanr``.
    """
    model.eval()

    def _embed(batch):
        # [batch, 1, seq_len] -> [batch, seq_len], then run the model.
        ids = batch.get('input_ids').squeeze(1).to(device)
        mask = batch.get('attention_mask').squeeze(1).to(device)
        segments = batch.get('token_type_ids').squeeze(1).to(device)
        return model(ids, mask, segments)

    similarities = torch.tensor([], device=device)
    gold_scores = np.array([])
    with torch.no_grad():
        for source, target, label in tqdm(dataloader):
            source_pred = _embed(source)
            target_pred = _embed(target)
            batch_sim = F.cosine_similarity(source_pred, target_pred, dim=-1)
            # Accumulate predictions and labels in matching order.
            similarities = torch.cat((similarities, batch_sim), dim=0)
            gold_scores = np.append(gold_scores, np.array(label))

    predicted = similarities.cpu().numpy()
    return spearmanr(gold_scores, predicted).correlation
def load_train_data_unsupervised(tokenizer, args):
    """
    Load the unsupervised training corpus as tokenized features.

    Each line of ``args.train_file`` is stripped and tokenized twice (the
    same sentence is passed as a two-element batch). The resulting list is
    pickled to ``train-unsupervise.pkl`` in the parent directory of
    ``args.output_path`` and reused on later calls unless
    ``args.overwrite_cache`` is set.

    :param tokenizer: callable applied to ``[line, line]``.
    :param args: namespace providing ``output_path``, ``overwrite_cache``,
        ``train_file`` and ``max_len``.
    :return: list with one tokenizer output per input line.
    """
    logger.info('loading unsupervised train data')
    cache_dir = os.path.dirname(args.output_path)
    train_file_cache = join(cache_dir, 'train-unsupervise.pkl')

    # Reuse the cached features unless a rebuild was requested.
    cache_usable = os.path.exists(train_file_cache) and not args.overwrite_cache
    if cache_usable:
        with open(train_file_cache, 'rb') as cache_fp:
            cached = pickle.load(cache_fp)
        logger.info("len of train data:{}".format(len(cached)))
        return cached

    with open(args.train_file, 'r', encoding='utf8') as corpus:
        lines = corpus.readlines()
    logger.info("len of train data:{}".format(len(lines)))

    feature_list = [
        tokenizer([sentence, sentence], max_length=args.max_len, truncation=True,
                  padding='max_length', return_tensors='pt')
        for sentence in map(str.strip, tqdm(lines))
    ]

    with open(train_file_cache, 'wb') as cache_fp:
        pickle.dump(feature_list, cache_fp)
    return feature_list
def load_train_data_supervised(tokenizer, args):
    """
    Load the supervised (NLI) training corpus as tokenized features.

    ``args.train_file`` is read as a comma-separated file with the columns
    ``sent0``, ``sent1`` and ``hard_neg``. Each row is tokenized as a
    three-element batch. The resulting list is pickled to
    ``train-supervised.pkl`` in the parent directory of ``args.output_path``
    and reused on later calls unless ``args.overwrite_cache`` is set.

    :param tokenizer: callable applied to ``[sent0, sent1, hard_neg]``.
    :param args: namespace providing ``output_path``, ``overwrite_cache``,
        ``train_file`` and ``max_len``.
    :return: list with one tokenizer output per CSV row.
    """
    logger.info('loading supervised train data')
    output_path = os.path.dirname(args.output_path)
    train_file_cache = join(output_path, 'train-supervised.pkl')
    if os.path.exists(train_file_cache) and not args.overwrite_cache:
        with open(train_file_cache, 'rb') as f:
            feature_list = pickle.load(f)
        logger.info("len of train data:{}".format(len(feature_list)))
        return feature_list

    df = pd.read_csv(args.train_file, sep=',')
    logger.info("len of train data:{}".format(len(df)))
    # The orient must be spelled exactly: the previous 'reocrds' only worked
    # through prefix matching in old pandas and is rejected by pandas 2.x.
    rows = df.to_dict('records')

    feature_list = []
    for row in tqdm(rows):
        sent0 = row['sent0']
        sent1 = row['sent1']
        hard_neg = row['hard_neg']
        feature = tokenizer([sent0, sent1, hard_neg], max_length=args.max_len, truncation=True,
                            padding='max_length', return_tensors='pt')
        feature_list.append(feature)

    with open(train_file_cache, 'wb') as f:
        pickle.dump(feature_list, f)
    return feature_list
def load_eval_data(tokenizer, args, mode):
    """
    Load the dev or test split as tokenized sentence pairs.

    The file (``args.dev_file`` or ``args.test_file``) is read line by line
    and split on tabs; every line must have 7 or 9 fields. Field 4 is the
    score and fields 5 and 6 are the two sentences. The result is pickled to
    ``<mode>.pkl`` in the parent directory of ``args.output_path`` and reused
    on later calls unless ``args.overwrite_cache`` is set.

    :param tokenizer: callable applied to each sentence.
    :param args: namespace providing ``output_path``, ``overwrite_cache``,
        ``dev_file``, ``test_file`` and ``max_len``.
    :param mode: ``'dev'`` or ``'test'``.
    :return: list of ``(features_1, features_2, score)`` tuples.
    """
    assert mode in ['dev', 'test'], 'mode should in ["dev", "test"]'
    logger.info('loading {} data'.format(mode))

    cache_dir = os.path.dirname(args.output_path)
    eval_file_cache = join(cache_dir, '{}.pkl'.format(mode))

    # Reuse the cached features unless a rebuild was requested.
    cache_usable = os.path.exists(eval_file_cache) and not args.overwrite_cache
    if cache_usable:
        with open(eval_file_cache, 'rb') as cache_fp:
            cached = pickle.load(cache_fp)
        logger.info("len of {} data:{}".format(mode, len(cached)))
        return cached

    eval_file = args.dev_file if mode == 'dev' else args.test_file
    with open(eval_file, 'r', encoding='utf8') as src:
        lines = src.readlines()
    logger.info("len of {} data:{}".format(mode, len(lines)))

    def _encode(sentence):
        # Tokenize one stripped sentence, padded/truncated to args.max_len.
        return tokenizer(sentence.strip(), max_length=args.max_len, truncation=True,
                         padding='max_length', return_tensors='pt')

    feature_list = []
    for raw in tqdm(lines):
        fields = raw.strip().split("\t")
        assert len(fields) in (7, 9)
        score = float(fields[4])
        feature_list.append((_encode(fields[5]), _encode(fields[6]), score))

    with open(eval_file_cache, 'wb') as cache_fp:
        pickle.dump(feature_list, cache_fp)
    return feature_list
def main(args):
    """
    Build the tokenizer and model, then train and/or evaluate as requested.

    When ``args.do_train`` is set, the training and dev data are loaded and
    :func:`train` is run. When ``args.do_predict`` is set, the checkpoint
    ``<args.output_path>/simcse.pt`` is loaded and scored on the test set.

    :param args: parsed command-line namespace; ``args.device`` is passed to
        ``.to()`` and to :func:`evaluate`.
    :return: None
    """
    # Load the tokenizer and model.
    tokenizer = BertTokenizer.from_pretrained(args.pretrain_model_path)
    assert args.pooler in ['cls', "pooler", "last-avg", "first-last-avg"], \
        'pooler should in ["cls", "pooler", "last-avg", "first-last-avg"]'
    model = SimcseModel(pretrained_model=args.pretrain_model_path, pooling=args.pooler, dropout=args.dropout).to(
        args.device)

    if args.do_train:
        # Load the datasets.
        assert args.train_mode in ['supervise', 'unsupervise'], \
            "train_mode should in ['supervise', 'unsupervise']"
        if args.train_mode == 'supervise':
            train_data = load_train_data_supervised(tokenizer, args)
        elif args.train_mode == 'unsupervise':
            train_data = load_train_data_unsupervised(tokenizer, args)
        train_dataset = TrainDataset(train_data, tokenizer, max_len=args.max_len)
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size_train, shuffle=True,
                                      num_workers=args.num_workers)

        dev_data = load_eval_data(tokenizer, args, 'dev')
        dev_dataset = TestDataset(dev_data, tokenizer, max_len=args.max_len)
        # Evaluation pairs each label with its own prediction, so order does
        # not affect the score; skipping the shuffle avoids drawing from the
        # global RNG during training.
        dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size_eval, shuffle=False,
                                    num_workers=args.num_workers)

        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
        train(model, train_dataloader, dev_dataloader, optimizer, args)

    if args.do_predict:
        test_data = load_eval_data(tokenizer, args, 'test')
        test_dataset = TestDataset(test_data, tokenizer, max_len=args.max_len)
        test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size_eval, shuffle=False,
                                     num_workers=args.num_workers)
        # map_location lets a checkpoint saved on one device (e.g. GPU) be
        # loaded when running on another (e.g. CPU).
        state_dict = torch.load(join(args.output_path, 'simcse.pt'), map_location=args.device)
        model.load_state_dict(state_dict)
        model.eval()
        corrcoef = evaluate(model, test_dataloader, args.device)
        logger.info('testset corrcoef:{}'.format(corrcoef))
if __name__ == '__main__':
    # Script entry point: parse options, seed the RNGs, prepare the output
    # directory, the log file and the TensorBoard writer, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default='gpu', choices=['gpu', 'cpu'], help="gpu or cpu")
    parser.add_argument("--output_path", type=str, default='output')
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size_train", type=int, default=4)
    parser.add_argument("--batch_size_eval", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--eval_step", type=int, default=100, help="every eval_step to evaluate model")
    parser.add_argument("--max_len", type=int, default=64, help="max length of input")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--train_file", type=str, default="data/nli_for_simcse.csv")
    parser.add_argument("--dev_file", type=str, default="data/stsbenchmark/sts-dev.csv")
    parser.add_argument("--test_file", type=str, default="data/stsbenchmark/sts-test.csv")
    parser.add_argument("--pretrain_model_path", type=str,
                        default="pretrain_model/bert-base-uncased")
    parser.add_argument("--pooler", type=str, choices=['cls', "pooler", "last-avg", "first-last-avg"],
                        default='cls', help='pooler to use')
    parser.add_argument("--train_mode", type=str, default='unsupervise', choices=['unsupervise', 'supervise'], help="unsupervise or supervise")
    parser.add_argument("--overwrite_cache", action='store_true', default=False, help="overwrite cache")
    # Both phases are enabled by default, so the positive flags alone can
    # never turn them off; the --no_* flags provide the missing off switch.
    parser.add_argument("--do_train", action='store_true', default=True)
    parser.add_argument("--no_train", dest='do_train', action='store_false', help="skip training")
    parser.add_argument("--do_predict", action='store_true', default=True)
    parser.add_argument("--no_predict", dest='do_predict', action='store_false', help="skip test-set evaluation")

    args = parser.parse_args()
    seed_everything(args.seed)
    args.device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == 'gpu' else "cpu")
    # One output directory per training mode and hyper-parameter combination.
    args.output_path = join(args.output_path, args.train_mode, 'bsz-{}-lr-{}-dropout-{}'.format(args.batch_size_train, args.lr, args.dropout))
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(args.output_path, exist_ok=True)

    cur_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
    logger.add(join(args.output_path, 'train-{}.log'.format(cur_time)))
    logger.info(args)

    # Kept as a module-level name because train() looks up ``writer`` globally.
    writer = SummaryWriter(args.output_path)
    try:
        main(args)
    finally:
        # Flush pending events and release the event file even if main() fails.
        writer.close()