run_rte.py
"""Train and evaluate a Recognizing Textual Entailment (RTE) model."""
import argparse
import sys

import torch
from sklearn.metrics import accuracy_score

import Lang as L
from rte_model import RTE
from utils import *  # provides get_best_model_file, used below

# ROOT_DIR = '/home/bass/DataDir/RTE/'
ROOT_DIR = ""
def get_arguments():
    def check_boolean(args, attr_name):
        # Normalize a "true"/"false" string argument into a real Python bool.
        assert hasattr(args, attr_name), "%s not found in parser" % (attr_name)
        bool_set = set(["true", "false"])
        args_value = getattr(args, attr_name).lower()
        assert args_value in bool_set, "Boolean argument required for attribute %s" % (attr_name)
        setattr(args, attr_name, args_value == "true")
        return args

    parser = argparse.ArgumentParser(description='Recognizing Textual Entailment')
    parser.add_argument('-n_embed', action="store", default=300, dest="n_embed", type=int)
    parser.add_argument('-n_dim', action="store", default=300, dest="n_dim", type=int)
    parser.add_argument('-batch', action="store", default=256, dest="batch_size", type=int)
    parser.add_argument('-dropout', action="store", default=0.1, dest="dropout", type=float)
    parser.add_argument('-l2', action="store", default=0.0003, dest="l2", type=float)
    parser.add_argument('-lr', action="store", default=0.001, dest="lr", type=float)
    # Strings stand in for boolean flags; they are validated and converted below.
    parser.add_argument('-last_nonlinear', action="store", default="false", dest="last_nonlinear", type=str)
    parser.add_argument('-train_flag', action="store", default="true", dest="train_flag", type=str)
    parser.add_argument('-continue_training', action="store", default="false", dest="continue_training", type=str)
    parser.add_argument('-wbw_attn', action="store", default="false", dest="wbw_attn", type=str)
    parser.add_argument('-use_pretrained', action="store", default="false", dest="use_pretrained", type=str)
    parser.add_argument('-debug', action="store", default="false", dest="debug", type=str)
    parser.add_argument('-h_maxlen', action="store", default=30, dest="h_maxlen", type=int)
    args = parser.parse_args(sys.argv[1:])
    # Checks for the boolean flags
    args = check_boolean(args, 'last_nonlinear')
    args = check_boolean(args, 'train_flag')
    args = check_boolean(args, 'continue_training')
    args = check_boolean(args, 'wbw_attn')
    args = check_boolean(args, 'use_pretrained')
    args = check_boolean(args, 'debug')
    return args
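# A minimal invocation sketch (hypothetical values; the flag names are the
# ones registered above). Booleans are passed as the strings "true"/"false":
#
#   python run_rte.py -batch 128 -dropout 0.2 -wbw_attn true -use_pretrained true
#
# After check_boolean() runs, args.wbw_attn and args.use_pretrained are real
# Python bools that downstream code can branch on directly.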
def get_options(args):
    options = {}
    # MISC
    options['DEBUG'] = args.debug if hasattr(args, 'debug') else False
    options['CLASSES_2_IX'] = {'neutral': 1, 'contradiction': 2, 'entailment': 0}
    options['VOCAB'] = ROOT_DIR + 'data/vocab.pkl'
    if options['DEBUG']:
        # Tiny splits for quick iteration; the tiny val split doubles as test.
        options['TRAIN_FILE'] = ROOT_DIR + 'data/tinyTrain.txt'
        options['VAL_FILE'] = ROOT_DIR + 'data/tinyVal.txt'
        options['TEST_FILE'] = ROOT_DIR + 'data/tinyVal.txt'
    else:
        options['TRAIN_FILE'] = ROOT_DIR + 'data/train.txt'
        options['VAL_FILE'] = ROOT_DIR + 'data/dev.txt'
        options['TEST_FILE'] = ROOT_DIR + 'data/test.txt'
    # Network Properties
    options['LAST_NON_LINEAR'] = args.last_nonlinear if hasattr(args, 'last_nonlinear') else False
    options['USE_PRETRAINED'] = args.use_pretrained if hasattr(args, 'use_pretrained') else False
    options['BATCH_SIZE'] = args.batch_size if hasattr(args, 'batch_size') else 256
    options['MAX_LEN'] = args.h_maxlen if hasattr(args, 'h_maxlen') else 30  # max hypothesis length
    options['DROPOUT'] = args.dropout if hasattr(args, 'dropout') else 0.1
    options['EMBEDDING_DIM'] = args.n_embed if hasattr(args, 'n_embed') else 300
    options['HIDDEN_DIM'] = args.n_dim if hasattr(args, 'n_dim') else 300
    options['L2'] = args.l2 if hasattr(args, 'l2') else 0.0003
    options['LR'] = args.lr if hasattr(args, 'lr') else 0.001
    options['WBW_ATTN'] = args.wbw_attn if hasattr(args, 'wbw_attn') else False
    # Build the save prefix that identifies this configuration on disk
    if options['WBW_ATTN']:
        options['SAVE_PREFIX'] = ROOT_DIR + 'models_wbw/model'
    else:
        options['SAVE_PREFIX'] = ROOT_DIR + 'models/model'
    if options['USE_PRETRAINED']:
        options['SAVE_PREFIX'] += '_USING_PRETRAINED_EMBEDDINGS'
    options['SAVE_PREFIX'] += '_EMBEDDING_DIM_%d' % (options['EMBEDDING_DIM'])
    options['SAVE_PREFIX'] += '_HIDDEN_DIM_%d' % (options['HIDDEN_DIM'])
    options['SAVE_PREFIX'] += '_DROPOUT_%.4f' % (options['DROPOUT'])
    options['SAVE_PREFIX'] += '_L2_%.4f' % (options['L2'])
    options['SAVE_PREFIX'] += '_LR_%.4f' % (options['LR'])
    options['SAVE_PREFIX'] += '_LAST_NON_LINEAR_%s' % (str(options['LAST_NON_LINEAR']))
    options['TRAIN_FLAG'] = args.train_flag if hasattr(args, 'train_flag') else True
    # Fallback matches the parser default of "false".
    options['CONTINUE_TRAINING'] = args.continue_training if hasattr(args, 'continue_training') else False
    return options
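# With the parser defaults (and ROOT_DIR == ""), SAVE_PREFIX works out to:
#   models/model_EMBEDDING_DIM_300_HIDDEN_DIM_300_DROPOUT_0.1000_L2_0.0003_LR_0.0010_LAST_NON_LINEAR_False
# Encoding the hyperparameters in the prefix means runs with different
# settings never clobber each other's checkpoints; get_best_model_file()
# below presumably uses this prefix to locate the checkpoint to load.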
args = get_arguments()
options = get_options(args)

l_en = L.Lang('en')
l_en.load_file(options['VOCAB'])


def data_generator(filename, l_en):
    # Despite the name, this loads the whole file into memory: it returns a
    # list of (premise_tokens, hypothesis_tokens) pairs and their labels.
    X = []
    y = []
    valid_labels = set(['neutral', 'contradiction', 'entailment'])
    unknown_count = 0
    with open(filename) as f:
        for line in f:
            line = line.strip().split('\t')
            if line[2] == '-':
                # '-' marks examples without a gold label; skip them.
                unknown_count += 1
                continue
            assert line[2] in valid_labels, "Unknown label %s" % (line[2])
            X.append((l_en.tokenize_sent(line[0]), l_en.tokenize_sent(line[1])))
            y.append(line[2])
    print('Num Unknowns : %d' % unknown_count)
    return X, y
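# Input format assumed by data_generator (inferred from the indexing above):
# one tab-separated example per line,
#   <premise sentence>\t<hypothesis sentence>\t<label>
# where <label> is neutral, contradiction, or entailment. The '-' convention
# matches SNLI-style data, where it marks examples with no gold consensus.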
rte_model = RTE(l_en, options)
if options['TRAIN_FLAG']:
    print("MODEL PROPERTIES:\n\tEMBEDDING_DIM : %d\n\tHIDDEN_DIM : %d" % (options['EMBEDDING_DIM'], options['HIDDEN_DIM']))
    print("\tDROPOUT : %.4f\n\tL2 : %.4f\n\tLR : %.4f\n\tLAST_NON_LINEAR : %s" % (options['DROPOUT'], options['L2'], options['LR'], str(options['LAST_NON_LINEAR'])))
    print("\tWBW ATTN : %s\n\tUSING PRETRAINED EMBEDDINGS : %s" % (str(options['WBW_ATTN']), str(options['USE_PRETRAINED'])))
    print('LOADING DATA ...')
    X_train, y_train = data_generator(options['TRAIN_FILE'], l_en)
    X_val, y_val = data_generator(options['VAL_FILE'], l_en)
    print('DATA LOADED:\nTRAINING SIZE : %d\nVALIDATION SIZE : %d' % (len(X_train), len(X_val)))
    if options['CONTINUE_TRAINING']:
        # Resume from the best checkpoint saved under SAVE_PREFIX.
        best_model_file = get_best_model_file(options['SAVE_PREFIX'], model_suffix='.model')
        best_model_state = torch.load(best_model_file)
        rte_model.load_state_dict(best_model_state)
    rte_model.fit(X_train, y_train, X_val, y_val, n_epochs=200)
else:
    # Evaluation only: load the best checkpoint and score the test split.
    best_model_file = get_best_model_file(options['SAVE_PREFIX'], model_suffix='.model')
    best_model_state = torch.load(best_model_file)
    rte_model.load_state_dict(best_model_state)
    X_test, y_test = data_generator(options['TEST_FILE'], l_en)
    preds_test = rte_model.predict(X_test, options['BATCH_SIZE'], probs=False)
    test_acc = accuracy_score([options['CLASSES_2_IX'][w] for w in y_test], preds_test)
    print("TEST ACCURACY FROM BEST MODEL : %.4f" % test_acc)