-
Notifications
You must be signed in to change notification settings - Fork 58
/
main.py
145 lines (125 loc) · 7.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import re
import argparse
import importlib
import pickle
import tensorflow as tf
from config import RUN_CONFIG
from dataset import NerDataset, MultiDataset
from tools.train_utils import build_model_fn,build_mtl_model_fn
from tools.utils import clear_model, build_estimator
from tools.infer_utils import EXPORT_DIR, get_receiver
def singletask_train(args):
    """Train, evaluate, predict with and export a single-dataset NER model.

    Checkpoints live in ``./checkpoint/ner_{data}_{model_name}`` and the
    dataset is read from ``./data/{data}``, where ``model_name`` is
    ``args.rename`` when given and ``args.model_name`` otherwise.

    With ``args.export_only`` set, the existing checkpoint is only exported
    for serving. Otherwise the model is trained and evaluated with an
    early-stopping hook on the ``'loss'`` metric. Predictions on the
    ``predict`` split are then pickled to
    ``./data/{data}/{model_name}_predict.pkl``, and the model is exported.

    Args:
        args: parsed command-line namespace. Reads ``rename``, ``model_name``,
            ``data``, ``clear_model``, ``gpu`` and ``export_only``.
    """
    model_name = args.rename if args.rename else args.model_name
    model_dir = './checkpoint/ner_{}_{}'.format(args.data, model_name)
    data_dir = './data/{}'.format(args.data)
    if args.clear_model:
        clear_model(model_dir)

    # Init dataset and pass parameter to train_params.
    # NOTE: TRAIN_PARAMS is the model module's own dict; it is updated in place.
    TRAIN_PARAMS = getattr(importlib.import_module('model.{}'.format(args.model_name)), 'TRAIN_PARAMS')
    input_pipe = NerDataset(data_dir, TRAIN_PARAMS['batch_size'], TRAIN_PARAMS['epoch_size'], model_name)
    TRAIN_PARAMS.update(input_pipe.params)  # add label_size, max_seq_len, num_train_steps into train_params

    print('=' * 10 + 'TRAIN PARAMS' + '=' * 10)
    # Embedding / vocabulary entries are left out of the printout
    # (presumably because they are too large to be useful on the console).
    print({key: value for key, value in TRAIN_PARAMS.items()
           if 'emb' not in key and 'vocab' not in key})
    print('=' * 10 + 'RUN PARAMS' + '=' * 10)
    print(RUN_CONFIG)

    # Init estimator
    model_fn = build_model_fn(args.model_name)
    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu, RUN_CONFIG)

    def export_serving_model():
        """Export the estimator's latest checkpoint to EXPORT_DIR for serving."""
        export_dir = EXPORT_DIR.format(model_name)
        print('Exporting Model for serving_model at {}'.format(export_dir))
        # NOTE(review): private Estimator attribute, presumably set to skip a
        # TPU-specific export path - confirm against build_estimator.
        estimator._export_to_tpu = False
        estimator.export_saved_model(export_dir,
                                     get_receiver(TRAIN_PARAMS['max_seq_len'], input_pipe.surfix))

    if args.export_only:
        # only export model when ckpt already exists
        export_serving_model()
        return

    # Run Train & Evaluate. Training is stopped once the evaluated 'loss' has
    # not decreased for step_per_epoch * early_stop_ratio training steps.
    early_stopping_hook = tf.estimator.experimental.stop_if_no_decrease_hook(
        estimator, metric_name='loss',
        max_steps_without_decrease=int(TRAIN_PARAMS['step_per_epoch'] * TRAIN_PARAMS['early_stop_ratio'])
    )
    train_spec = tf.estimator.TrainSpec(input_pipe.build_input_fn('train'), hooks=[early_stopping_hook])
    eval_spec = tf.estimator.EvalSpec(input_pipe.build_input_fn('valid', is_predict=True), throttle_secs=60)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    # Do prediction when train finished; materialise the generator so the
    # whole result list can be pickled.
    prediction = list(estimator.predict(input_fn=input_pipe.build_input_fn('predict', is_predict=True)))
    with open('{}/{}_predict.pkl'.format(data_dir, model_name), 'wb') as f:
        pickle.dump(prediction, f)

    export_serving_model()
def multitask_train(args):
    """Train a multi-task or adversarial NER model over several datasets.

    ``args.data`` must name more than one dataset (comma separated), and
    ``args.model_name`` must be a corresponding multi-task/adversarial model.
    Checkpoints live in ``./checkpoint/ner_{task_tag}_{model_name}``, where
    ``task_tag`` is the dataset names joined with ``'_'``.

    With ``args.export_only`` set, the existing checkpoint is only exported
    for serving. Otherwise the model is trained and evaluated with an
    early-stopping hook on the ``'loss'`` metric. Per-dataset predictions are
    then pickled to ``./data/{dataset}/{model_name}_{task_tag}_predict.pkl``,
    and the model is exported.

    Args:
        args: parsed command-line namespace. Reads ``rename``, ``model_name``,
            ``data``, ``clear_model``, ``gpu`` and ``export_only``.
    """
    model_name = args.rename if args.rename else args.model_name
    data_dir = './data'
    data_list = args.data.split(',')
    # The joined dataset names tag both the checkpoint dir and prediction files.
    task_tag = '_'.join(data_list)
    model_dir = './checkpoint/ner_{}_{}'.format(task_tag, model_name)
    if args.clear_model:
        clear_model(model_dir)

    # Init dataset and pass parameter to train_params.
    # NOTE: TRAIN_PARAMS is the model module's own dict; it is updated in place.
    TRAIN_PARAMS = getattr(importlib.import_module('model.{}'.format(args.model_name)), 'TRAIN_PARAMS')
    input_pipe = MultiDataset(data_dir, data_list, TRAIN_PARAMS['batch_size'], TRAIN_PARAMS['epoch_size'], model_name)
    TRAIN_PARAMS.update(input_pipe.params)  # add label_size, max_seq_len, num_train_steps into train_params

    print('=' * 10 + 'TRAIN PARAMS' + '=' * 10)
    # Same filtering as singletask_train: embedding / vocabulary entries are
    # left out of the printout.
    print({key: value for key, value in TRAIN_PARAMS.items()
           if 'emb' not in key and 'vocab' not in key})
    print('=' * 10 + 'RUN PARAMS' + '=' * 10)
    print(RUN_CONFIG)

    # Init estimator
    model_fn = build_mtl_model_fn(args.model_name)
    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu, RUN_CONFIG)

    def export_serving_model():
        """Export the estimator's latest checkpoint to EXPORT_DIR for serving."""
        export_dir = EXPORT_DIR.format(model_name)
        print('Exporting Model for serving_model at {}'.format(export_dir))
        # NOTE(review): private Estimator attribute, presumably set to skip a
        # TPU-specific export path - confirm against build_estimator.
        estimator._export_to_tpu = False
        estimator.export_saved_model(export_dir,
                                     get_receiver(TRAIN_PARAMS['max_seq_len'], None, True))

    if args.export_only:
        # only export model when ckpt already exists
        export_serving_model()
        return

    # Run Train & Evaluate. Training is stopped once the evaluated 'loss' has
    # not decreased for step_per_epoch * early_stop_ratio training steps.
    early_stopping_hook = tf.estimator.experimental.stop_if_no_decrease_hook(
        estimator, metric_name='loss',
        max_steps_without_decrease=int(TRAIN_PARAMS['step_per_epoch'] * TRAIN_PARAMS['early_stop_ratio'])
    )
    train_spec = tf.estimator.TrainSpec(input_pipe.build_input_fn('train'), hooks=[early_stopping_hook])
    eval_spec = tf.estimator.EvalSpec(input_pipe.build_input_fn('valid'), throttle_secs=60)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    # Do prediction for every dataset when train finished.
    for data in data_list:
        print('Prediction for {}'.format(data))
        prediction = list(estimator.predict(input_pipe.build_predict_fn(data)))
        # for multi-task, prediction file name is {model_name}_{task_list}_predict
        with open('{}/{}/{}_{}_predict.pkl'.format(data_dir, data, model_name, task_tag), 'wb') as f:
            pickle.dump(prediction, f)

    export_serving_model()
if __name__ == '__main__':
    # Command-line entry point: parse options, configure the process
    # environment, then dispatch to single- or multi-task training.
    import os

    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('--model_name', type=str, required=True,
        help='model_name[bert_bilstm_crf, bert_crf, bert_ce]')
    add('--clear_model', type=int, required=False, default=0,
        help='Whether to clear existing model')
    add('--data', type=str, required=False, default='msra',
        help='which data to use[msra, cluener, people_daily]')
    add('--gpu', type=int, required=False, default=0,
        help='Whether to enable gpu')
    add('--device', type=int, required=False, default=-1,
        help='which gpu to use')
    add('--rename', type=str, required=False, default='',
        help='Allow rename model with special parameter')
    add('--export_only', type=int, required=False, default=0,
        help='Export Model without training when ckpt exists')
    args = parser.parse_args()

    # Choose the visible GPU (the default -1 is an invalid index, which CUDA
    # treats as "no visible devices") and silence TensorFlow's C++ logging.
    # NOTE(review): tensorflow is already imported at the top of the file, so
    # these variables may come too late for messages emitted during import.
    os.environ.update({
        "CUDA_VISIBLE_DEVICES": str(args.device),
        'TF_CPP_MIN_LOG_LEVEL': '3',  # disable debugging logging
    })

    # A comma in --data means several datasets, which selects multi-task training.
    train = multitask_train if ',' in args.data else singletask_train
    train(args)