# run_experiment.py
# Run experiments for custom interpreter
# You should define the interpreter function below the imports. It takes the model, question and context as input and returns the sentence attribution scores (high score = important).
# Trivial Example:
# >> from nltk.tokenize import sent_tokenize
# >> interpreter = lambda model, question, context: list(range(len(sent_tokenize(context))))
import argparse
import torch
import pandas as pd
import qa_experimenters
import models
import os
# IMPORT YOUR CODE HERE
# "interpreter" should be defined
# +----------------------------+
# | CUSTOM CODE START |
# +----------------------------+
from nltk.tokenize import sent_tokenize
def interpreter(model, question, context):
"""
Trivial interpreter which always returns the first sentence as being the most important one.
"""
return [1] + [0] * (len(sent_tokenize(context)) - 1)
# +----------------------------+
# | CUSTOM CODE END |
# +----------------------------+
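
# A second, purely illustrative interpreter (hypothetical, kept commented out): it scores
# each sentence by its word count, simply to show that any list with one score per sentence
# (higher = more important) satisfies the expected interface. It reuses sent_tokenize from
# the import above.
# def interpreter(model, question, context):
#     sentences = sent_tokenize(context)
#     return [len(sentence.split()) for sentence in sentences]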
def run(dataset, device):
    try:
        interpreter
    except NameError:
        raise NotImplementedError(
            "Interpreter not defined, check the top of this file for instructions on how to define it."
        )
    else:
        # run for QA model
        print("Running with QA model...")
        experiment = (
            qa_experimenters.SQuADExperimenter(
                models.Model_QA(device=device), interpreter, split=10
            )
            if dataset == "SQuAD"
            else qa_experimenters.SQuADShiftsExperimenter(
                models.Model_QA(device=device), interpreter, dataset.lower()
            )
        )
        experiment.experiment()
        if not os.path.exists("results_quackie/results_QA_" + dataset):
            os.makedirs("results_quackie/results_QA_" + dataset)
        experiment.save(path="results_quackie/results_QA_" + dataset)
        print("Done!")

        # run for Classification model
        print("Running with Classification model...")
        experiment = (
            qa_experimenters.SQuADExperimenter(
                models.Model_Classification(device=device), interpreter, split=10
            )
            if dataset == "SQuAD"
            else qa_experimenters.SQuADShiftsExperimenter(
                models.Model_Classification(device=device), interpreter, dataset.lower()
            )
        )
        experiment.experiment()
        if not os.path.exists("results_quackie/results_Classification_" + dataset):
            os.makedirs("results_quackie/results_Classification_" + dataset)
        experiment.save(path="results_quackie/results_Classification_" + dataset)
        print("Done!")
def analyze(args):
    E = qa_experimenters.QAExperimenter(None, None, None)
    df = pd.DataFrame(
        columns=[
            "model",
            "interpreter",
            "mean_iou",
            "mean_hpd",
            "mean_snr",
            "std_iou",
            "std_hpd",
            "std_snr",
            "fails",
            "no_snr",
            "info",
            "dataset",
        ]
    )
    for folder in os.listdir("results_quackie"):
        E.load("results_quackie/" + folder)
        res = E.analyze(interpreter)[
            [
                "mean_iou",
                "mean_hpd",
                "mean_snr",
                "std_iou",
                "std_hpd",
                "std_snr",
                "fails",
                "no_snr",
            ]
        ]
        res["model"] = folder.split("_")[1]
        res["dataset"] = folder.split("_")[2]
        res["interpreter"] = args.name
        res["info"] = args.info
        df = df.append(res)
    df = df[
        ["dataset", "model", "interpreter", "info", "mean_iou", "mean_hpd", "mean_snr"]
    ]
    df.columns = ["dataset", "classifier", "method", "info", "IoU", "HPD", "SNR"]
    print(df)
    df.to_json("results.json", orient="records")
    print("Saved results in results.json")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Parameters of the experiment")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["SQuAD", "NEW_WIKI", "NYT", "Reddit", "Amazon"],
        help="Dataset to use, one of ['SQuAD', 'NEW_WIKI', 'NYT', 'Reddit', 'Amazon']: 'SQuAD' for the SQuAD dataset, the others for the corresponding SQuADShifts datasets",
    )
    parser.add_argument(
        "--run",
        dest="run",
        action="store_true",
        help="Flag to run the experiment",
    )
    parser.add_argument(
        "--analyze",
        dest="analyze",
        action="store_true",
        help="Flag to analyze the experiment (all results will be analyzed)",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="Custom",
        help="Name of the interpreter to use in the results",
    )
    parser.add_argument(
        "--info", type=str, default="Custom", help="Content for the 'Info' column",
    )
    parser.add_argument(
        "--no_cuda", dest="no_cuda", action="store_true", help="Flag to force CPU usage"
    )
    args = parser.parse_args()

    if not (args.run or args.analyze):
        print("Please specify --run and/or --analyze")
    else:
        device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        print("Using device {}".format(device))
        if args.run:
            assert args.dataset, "Dataset required for running"
            run(args.dataset, device)
        if args.analyze:
            print("\nAnalyzing for all datasets...")
            analyze(args)
    print(
        "\n\nThank you for choosing QUACKIE. You can submit your results via git pull request, more info here:"
    )
    print("https://github.com/axa-rev-research/quackie/tree/gh-pages")