-
Notifications
You must be signed in to change notification settings - Fork 0
/
BERT_lime.py
71 lines (49 loc) · 1.83 KB
/
BERT_lime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from typing import Dict, List, Union
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
import dill
from pathlib import Path
from pysentimiento import create_analyzer
from lime.lime_text import LimeTextExplainer
from pysentimiento.analyzer import AnalyzerOutput
def sort_sentiment(res: AnalyzerOutput) -> np.ndarray:
    """Return the scores of a single analyzer result as a one-row array.

    The values are read from ``res.probas`` and ordered by the labels in
    the module-level ``sentiment.id2label`` mapping, so column ``i`` always
    corresponds to the ``i``-th label of that mapping.

    Args:
        res: A single analyzer output exposing a ``probas`` mapping keyed
            by label name.

    Returns:
        Array of shape ``(1, n_labels)``.
    """
    # Index by label rather than relying on the iteration order of
    # ``res.probas`` so the column order is fixed by the model's label map.
    vals = [res.probas[k] for k in sentiment.id2label.values()]
    return np.array(vals).reshape(1, -1)
def list_to_arr(result: List[AnalyzerOutput]) -> np.ndarray:
    """Stack the per-item rows of several analyzer outputs into one array.

    Each element of ``result`` is converted with ``sort_sentiment`` and the
    resulting one-row arrays are stacked vertically, one row per element.
    """
    rows = []
    for item in result:
        rows.append(sort_sentiment(item))
    return np.vstack(rows)
def format_output(result: Union[List[AnalyzerOutput], AnalyzerOutput]) -> np.ndarray:
    """Convert a single analyzer output, or a list of them, to a 2-D array.

    Args:
        result: Either one analyzer output (an object with a ``probas``
            attribute) or a list of such outputs.

    Returns:
        The array built by ``sort_sentiment`` for a single output, or by
        ``list_to_arr`` for a list of outputs.
    """
    # Dispatch explicitly on the attribute that ``sort_sentiment`` reads.
    # Wrapping the whole call in ``try/except AttributeError`` would also
    # swallow unrelated AttributeErrors raised inside it and then misroute a
    # single result into the list branch.
    if hasattr(result, "probas"):
        return sort_sentiment(result)
    return list_to_arr(result)
def dict_to_arr(dct: dict) -> np.ndarray:
    """Lay out the values of ``dct`` as rows of ``len(dct)`` columns.

    Args:
        dct: Mapping whose values become the array entries, in the dict's
            iteration order.

    Returns:
        Array reshaped to ``(-1, len(dct))``; for a dict of scalars this is
        a single row. An empty dict yields an empty array of shape
        ``(1, 0)``.
    """
    n_feats = len(dct)
    if not dct:
        # reshape((-1, 0)) is ambiguous for an empty array and raises a
        # ValueError, so return the empty single-row result directly.
        return np.empty((1, 0))
    return np.array(list(dct.values())).reshape((-1, n_feats))
def predict_pos_proba(sentence: Union[str, List[str]]) -> np.ndarray:
    """Run the module-level ``sentiment`` analyzer and return scores as an array.

    This is the function handed to LIME as its classifier callback, which
    calls it with a list of texts; the script also calls it directly with a
    single string.

    Args:
        sentence: One text or a list of texts, passed unchanged to
            ``sentiment.predict``.

    Returns:
        The prediction converted by ``format_output``: one row per input
        text, one column per label.
    """
    pred = sentiment.predict(sentence)
    return format_output(pred)
# Module-level analyzer; the helper functions in this file read it as a global.
sentiment = create_analyzer("sentiment", lang="en")

# Quick check on two hand-written sentences.
sentence = ["I'm tweeting and I'm happy!", "I'm sad"]
output = sentiment.predict(sentence)
predict_pos_proba(sentence)  # result is discarded; only exercises the helper
labels = list(sentiment.id2label.values())
list_to_arr(output)  # result is discarded; only exercises the helper
explainer = LimeTextExplainer(class_names=labels)
explains = explainer.explain_instance(sentence[0], predict_pos_proba, num_features=3)

# Test on real data
test_dat = pd.read_csv(
    Path("./output/clean_tweets.csv"),
    header=0,
    # Always keep row 0 (the header) and keep each other row with
    # probability of about 1%.
    skiprows=lambda i: i > 0 and random.random() > 0.01,
)
testy = test_dat.sample(1)
sentence = testy["cleantext"].tolist()[0]
# Plain int so the label can be used as a key when querying the explanation.
top_label = int(np.argmax(predict_pos_proba(sentence)))
explains = explainer.explain_instance(
    sentence, predict_pos_proba, num_features=5, labels=[top_label]
)
# Only ``top_label`` was explained, so it must be requested explicitly:
# as_list() and as_pyplot_figure() default to label=1 and raise KeyError
# whenever the top class is not index 1.
explains.as_list(label=top_label)
fig = explains.as_pyplot_figure(label=top_label)
plt.show()