# utilities.py
import os
import re
import sys
from contextlib import contextmanager
from typing import List, Optional
from nltk.corpus import wordnet as wn
from pywsd.lesk import simple_lesk
from pywsd.utils import lemmatize
from colorama import Style, Fore
import requests
from tqdm import tqdm
from dotenv import load_dotenv
from llm import wsd
tqdm.pandas()
load_dotenv()

@contextmanager
def silence_output(stdout=True, stderr=True):
    # Save the original stdout and stderr
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    # Redirect stdout and/or stderr to a null device
    null_device = open(os.devnull, 'w')
    try:
        if stdout:
            sys.stdout = null_device
        if stderr:
            sys.stderr = null_device
        yield
    finally:
        # Restore stdout and stderr, and close the null device
        sys.stdout = original_stdout
        sys.stderr = original_stderr
        null_device.close()
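
# Illustrative sketch (not part of the original module): pywsd and nltk can be
# chatty during disambiguation, so noisy calls are wrapped like this. The
# `_demo_silence_output` name is hypothetical, for illustration only.
def _demo_silence_output():
    with silence_output():
        print('swallowed by the null device')
    print('printing works again once the context exits')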

def get_all_defs(word: str, dictionary='wordnet'):
    word = lemmatize(word)
    # print(f'Getting definitions for {word} in {dictionary}...')
    if dictionary == 'wordnet':
        word = word.replace(' ', '_')  # wn uses _ to combine compound words
        synsets = wn.synsets(word)
        return [get_wn_def(synset) for synset in synsets]
    elif dictionary == 'merriam-webster':
        def get_defs_from_def_obj(def_obj) -> List[str]:
            definitions = []
            for sense_obj in def_obj['sseq']:
                if sense_obj[0][0] == 'sense':  # has defining text (dt)
                    defining_text = sense_obj[0][1]['dt'][0]
                    if defining_text[0] == 'text':  # has the actual definition
                        definition = defining_text[1]
                        # remove formatting strings
                        definition = definition.replace("{bc}", "", 1)  # remove first "{bc}"
                        definition = definition.replace(" {bc}", "; ")  # replace remaining "{bc}" with ";"
                        cr_pattern = r'\{(?:dx|dx_def|dx_ety|ma)\}.*?\{\/(?:dx|dx_def|dx_ety|ma)\}'
                        definition = re.sub(cr_pattern, '', definition)  # remove cross references
                        definition = re.sub(r'\{.*?\}', '', definition)  # remove all other formatting strings
                        definition = definition.rstrip()
                        definitions.append(f'{pos} {definition}')  # `pos` is closed over from the entry loop below
            return definitions
        try:
            pos_map = {
                'noun': 'n.',
                'verb': 'v.',
                'adjective': 'adj.',
                'adverb': 'adv.',
                'preposition': 'prep.',
                'pronoun': 'pron.',
                'conjunction': 'conj.',
                'interjection': 'interj.',
                'determiner': 'det.',
                'article': 'art.',
            }
            definitions = []
            entries = call_mw_api(word)
            for entry in entries:
                if not isinstance(entry, dict) or 'meta' not in entry:  # not an entry
                    continue
                if 'fl' not in entry or entry['fl'] not in pos_map:  # no pos tag
                    continue
                if word.lower() not in entry['meta']['stems']:
                    continue
                pos = pos_map[entry['fl']]
                if 'def' in entry:
                    for def_obj in entry['def']:
                        definitions.extend(get_defs_from_def_obj(def_obj))
                elif 'dros' in entry:  # e.g. "Richter scale" has no "def" but has "dros"
                    for defined_run_ons in entry['dros']:
                        for def_obj in defined_run_ons['def']:
                            definitions.extend(get_defs_from_def_obj(def_obj))
                # TODO save usage info in "uns" and potentially example sentences in "vis"
        except Exception as e:
            print(f'In get_all_defs: {e}')
            definitions = []
        if not definitions:
            definitions = get_all_defs(word, dictionary='wordnet')  # fall back to WordNet
        return definitions
    else:
        raise ValueError(f'Unknown dictionary: {dictionary}')
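
# Illustrative call (not part of the original module): both branches return a
# flat list of pos-prefixed definition strings; the 'merriam-webster' branch
# needs MERRIAM_WEBSTER_API_KEY in .env and falls back to WordNet on failure.
def _demo_get_all_defs():
    for definition in get_all_defs('bank', dictionary='wordnet'):
        print(definition)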

def call_mw_api(word: str):
    '''
    Calls the Merriam-Webster API and returns a list of entries.
    :param word: word to look up
    :return: a list of entries
    Note:
        The API lemmatizes the word, so the returned entries may not contain the exact word.
        It's recommended to lemmatize the word before calling this function.
    # TODO: implement a rate limiter (utilization limit: 30 hits per 60 sec)
    '''
    app_key = os.getenv('MERRIAM_WEBSTER_API_KEY')
    if app_key is None:
        raise ValueError('MERRIAM_WEBSTER_API_KEY is not set in .env')
    url = f'https://dictionaryapi.com/api/v3/references/learners/json/{word.lower()}?key={app_key}'
    response = requests.get(url)
    response.raise_for_status()  # surface HTTP errors to the caller
    return response.json()
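
# A minimal sketch of the rate limiter the TODO above calls for (30 hits per
# 60-second window), assuming a simple sliding-window approach. The helper
# name and strategy are hypothetical, not part of the original module.
_mw_call_times: List[float] = []

def _rate_limited_mw_call(word: str, max_hits: int = 30, window: float = 60.0):
    import time
    now = time.monotonic()
    # Keep only timestamps still inside the window, then wait out the oldest
    # one if the cap has been reached.
    _mw_call_times[:] = [t for t in _mw_call_times if now - t < window]
    if len(_mw_call_times) >= max_hits:
        time.sleep(window - (now - _mw_call_times[0]))
    _mw_call_times.append(time.monotonic())
    return call_mw_api(word)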

def get_wn_def(synset):
    if synset is None:
        return ''
    lexname_map = {
        'noun': 'n.',
        'verb': 'v.',
        'adj': 'adj.',
        'adv': 'adv.',
    }
    definition = synset.definition()
    lexname = lexname_map[synset.lexname().split('.')[0]]  # e.g. 'noun.animal' -> 'n.'
    return f'{lexname} {definition}'

def predict_def(sent: str, word: str, mode: str, definitions: Optional[List[str]] = None) -> str:
    if mode == 'simple_lesk':
        try:
            word = word.replace(' ', '_')  # wn uses _ to combine compound words
            synset = simple_lesk(sent, word)
            return get_wn_def(synset)
        except Exception as e:
            print(f'In predict_def: {e}')
            return ''
    elif mode == 'chatgpt_wsd':
        if not definitions:
            return ''
        idx = wsd(sent, word, definitions, model='chatgpt') if len(definitions) > 1 else 0
        if idx == -1 or idx >= len(definitions):
            return ''  # TODO consider using chatgpt_generation as fallback
        return definitions[idx]
    elif mode == 'chatgpt_generation':
        raise NotImplementedError
    else:
        raise ValueError(f'Unknown mode: {mode}')
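
# Illustrative usage (not part of the original module): 'simple_lesk' mode
# disambiguates against WordNet directly, while 'chatgpt_wsd' picks an index
# from a caller-supplied definition list (see get_all_defs).
def _demo_predict_def():
    sent = 'I went to the bank to deposit my paycheck.'
    print(predict_def(sent, 'bank', mode='simple_lesk'))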

def get_def_manual(vocabs, mode='simple_lesk', dictionary='wordnet', report_incorrect=False):
    # TODO support different modes
    correct_count = incorrect_count = 0
    for vocab in vocabs.itertuples():
        definitions = get_all_defs(vocab.word, dictionary)
        if len(definitions) == 1:
            vocabs.loc[vocab.Index, 'definition'] = definitions[0]
        elif len(definitions) == 0:
            vocabs.loc[vocab.Index, 'definition'] = ''
        else:
            word_r = f'{Fore.RED}{vocab.word}{Style.RESET_ALL}'
            print(f'What does {word_r} mean in this context?\n"{vocab.usage.replace(vocab.word, word_r)}"')
            predicted_definition = predict_def(vocab.usage, vocab.stem, mode=mode, definitions=definitions)
            # print all definitions, highlighting the predicted one in green
            correct_idx = None
            for i, definition in enumerate(definitions):
                if definition == predicted_definition:
                    correct_idx = i
                    print(Fore.GREEN, end='')
                print(f'{str(i).ljust(4)}{definition}{Style.RESET_ALL}')
            user_input = input(f'({vocab.Index + 1}/{len(vocabs)}) Type number: ')
            try:
                idx = int(user_input) if user_input else -1
            except ValueError:
                idx = -1  # non-numeric input falls back to the predicted definition
            if not 0 <= idx < len(definitions) or idx == correct_idx:
                # empty, invalid, or matching input accepts the prediction
                correct_count += 1
                vocabs.loc[vocab.Index, 'definition'] = predicted_definition
            else:
                incorrect_count += 1
                vocabs.loc[vocab.Index, 'definition'] = definitions[idx]
    if report_incorrect:
        total_wsd_count = incorrect_count + correct_count
        print(f'Of all {len(vocabs)} vocabs, {total_wsd_count} required WSD.')
        if total_wsd_count > 0:
            accuracy = round(correct_count / total_wsd_count * 100, 2)
            print(f'WSD accuracy is {accuracy}% ({correct_count}/{total_wsd_count})')
    return vocabs

def get_def_auto(vocabs, mode='simple_lesk', dictionary='wordnet'):
    vocabs['definition'] = vocabs[['word', 'stem', 'usage']].progress_apply(
        lambda x: predict_def(x['usage'], x['stem'], mode=mode, definitions=get_all_defs(x['word'], dictionary)),
        axis=1,
    )
    # bold and italicize the word in its usage sentence for display
    vocabs['usage'] = vocabs[['word', 'usage']].apply(
        lambda x: x['usage'].replace(x['word'], '<b><i>' + x['word'] + '</i></b>'), axis=1
    )
    return vocabs
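
# Illustrative driver (not part of the original module): the get_def_* helpers
# expect a DataFrame with 'word', 'stem', and 'usage' columns; the sample
# frame below is made-up data for a quick smoke test.
if __name__ == '__main__':
    import pandas as pd
    sample = pd.DataFrame({
        'word': ['bank'],
        'stem': ['bank'],
        'usage': ['I went to the bank to deposit my paycheck.'],
    })
    print(get_def_auto(sample, mode='simple_lesk')[['word', 'definition']])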