-
Notifications
You must be signed in to change notification settings - Fork 3
/
llm.py
66 lines (50 loc) · 2.03 KB
/
llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
from typing import List
import openai
from dotenv import load_dotenv
# Load settings from a .env file (python-dotenv) before reading the key below.
load_dotenv()
# os.getenv returns None when OPENAI_API_KEY is not set in the environment.
openai.api_key = os.getenv('OPENAI_API_KEY')
def wsd(sent: str, word: str, definitions: List[str], model='chatgpt') -> int:
    """Ask an LLM which definition of ``word`` best fits ``sent``.

    Args:
        sent: Sentence in which ``word`` appears.
        word: Target word to disambiguate.
        definitions: Candidate definitions, passed to ``get_wsd_prompt``.
        model: Backend to query. Only ``'chatgpt'`` is supported.

    Returns:
        The integer the model answered with. The value is not range-checked
        against ``definitions``; callers must validate it before indexing.

    Raises:
        ValueError: If ``model`` is not supported, or if the model's reply
            contains no integer.
    """
    # Reject unsupported backends up front, before building the prompt.
    if model != 'chatgpt':
        raise ValueError(f'Unknown model: {model}')

    prompt = get_wsd_prompt(sent, word, definitions)
    return _parse_def_index(get_completion(prompt))


def _parse_def_index(reply: str) -> int:
    """Extract the definition index from a raw model reply."""
    import re  # local import so this block does not depend on file-level imports

    # Fast path: the reply is already a bare integer.
    try:
        return int(reply)
    except ValueError:
        pass

    # The prompt shows the answer format as "1" (with quotes), and chat models
    # often add punctuation or a few words, so fall back to the first integer.
    found = re.search(r'-?\d+', reply)
    if found is None:
        raise ValueError(f'No definition index found in model reply: {reply!r}')
    return int(found.group())
def get_wsd_prompt(sent: str, word: str, definitions: List[str]) -> str:
    """Build the word-sense-disambiguation prompt for the model.

    Each definition is listed on its own line as ``<index>. <definition>``,
    with indices starting at 0, followed by instructions to answer with a
    number only.
    """
    numbered_lines = (f'{idx}. {text}\n' for idx, text in enumerate(definitions))
    definition_block = ''.join(numbered_lines)
    return f'''
Which definition of "{word}" fits in the context of the following sentence best?
Sentence: "{sent}"
Definitions of {word}:
{definition_block}
Answer with the number of the definition in this format: "1".
If you think none of the definitions fit, answer with the number
of the definition that is closest to the meaning of the word.
The answer should contain a number only and nothing else.
'''
def get_completion(prompt, model="gpt-3.5-turbo"):
    """Send ``prompt`` as a single user message and return the reply text.

    The request is made with temperature 0. The content of the first returned
    choice is stripped of surrounding whitespace before being returned.
    """
    user_message = {"role": "user", "content": prompt}
    completion = openai.ChatCompletion.create(
        model=model,
        messages=[user_message],
        temperature=0,  # degree of randomness of the model's output
    )
    first_choice = completion.choices[0]
    reply_text = first_choice.message["content"]  # type: ignore
    return reply_text.strip()
if __name__ == '__main__':
    # Manual demo: disambiguate one word with the LLM, then print the result
    # of predict_def in 'simple_lesk' mode for the same sentence.
    from utilities import get_all_defs, predict_def

    word = 'observe'
    sentence = 'Flag day is observed on June 14 and commemorates the adoption of the Stars and Stripes as the official flag of the United States.'
    definitions = get_all_defs(word)
    predicted_def_idx = wsd(sentence, word, definitions)
    print(predicted_def_idx)
    # Require a non-negative index: a negative value would otherwise wrap
    # around to the end of the list instead of being reported as invalid.
    if 0 <= predicted_def_idx < len(definitions):
        print(definitions[predicted_def_idx])
    else:
        print('error definition')
    print(predict_def(sentence, word, mode='simple_lesk'))