-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
153 lines (130 loc) · 5.38 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import logging
import os
import time
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List
import openai
class ChatLanguageModel(ABC):
    """Abstract base class for chat-style language-model wrappers.

    Stores the sampling configuration shared by concrete back-ends, a
    per-instance conversation history (``chat_memory``) and an optional
    system prompt (``system_prompt``).
    """

    def __init__(self,
                 engine: str,
                 device: str = "",
                 temperatue=0.1,
                 topp=0.95,
                 frequency_penalty=0.0,
                 presence_penalty=0.0):
        # NOTE: the keyword ``temperatue`` is misspelled but is part of the
        # public signature; its value is stored as ``self.temperature``.
        self.engine, self.device = engine, device
        self.temperature, self.topp = temperatue, topp
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        # Fresh list per instance, so histories are never shared.
        self.chat_memory = []
        # Unset until set_system_prompt() is called.
        self.system_prompt = None

    def set_system_prompt(self, prompt: str) -> None:
        """Record *prompt* as this model's system prompt (``self.system_prompt``)."""
        self.system_prompt = prompt

    @abstractmethod
    def get_template_based_responses(self,
                                     conversation_template: List[str],
                                     input_informations: List[Dict[str, Any]]) -> List[List[Optional[str]]]:
        """Run the conversation template against every input.

        Subclasses return one list of (possibly ``None``) responses per entry
        of *input_informations*.
        """
        raise NotImplementedError
class OpenaiChatGpt(ChatLanguageModel):
    """Chat model backed by ``openai.ChatCompletion.create``.

    Each successful :meth:`create_response` call appends the user turn and the
    assistant reply to ``chat_memory``. Consecutive calls therefore form one
    multi-turn conversation until :meth:`clear_chat_memory` is called.
    """

    # Maximum number of attempts create_response() makes before giving up.
    MAX_RETRIES = 10
    # Seconds to wait between failed attempts (never after the last one).
    RETRY_DELAY_SECONDS = 30
    # Value passed to the API as ``max_tokens``.
    MAX_TOKENS = 256

    def __init__(self, engine: str, device: str = "", temperatue=0.1, topp=0.95, frequency_penalty=0.0,
                 presence_penalty=0.0):
        # Raises KeyError if OPENAI_API_KEY is not set in the environment.
        openai.api_key = os.environ["OPENAI_API_KEY"]
        super().__init__(engine, device, temperatue, topp, frequency_penalty, presence_penalty)

    def _get_chat_messages(self, context: str) -> List[Dict[str, Any]]:
        """Build the request messages.

        The list holds the optional system prompt, then the stored history,
        then *context* as the new user turn.
        """
        messages: List[Dict[str, Any]] = []
        if self.system_prompt is not None:
            messages.append({'role': 'system', 'content': self.system_prompt})
        messages.extend(self.chat_memory)
        messages.append({'role': 'user', 'content': context})
        return messages

    def clear_chat_memory(self) -> None:
        """Discard the stored conversation history."""
        self.chat_memory = []

    def create_response(self, content: str) -> Optional[str]:
        """Send *content* as the next user turn and return the assistant's reply.

        Up to ``MAX_RETRIES`` attempts are made, sleeping
        ``RETRY_DELAY_SECONDS`` between failed attempts. Invalid-request and
        authentication errors are not retried, since repeating the same
        request cannot fix them. On success the user turn and the reply are
        appended to ``chat_memory``.

        Returns:
            The reply text, or ``None`` if every attempt failed or a
            non-retryable error occurred.
        """
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                response = self._create_response_chat(self._get_chat_messages(content))
            except openai.error.RateLimitError as e:
                logging.warning("Rate limit reached (attempt %d/%d): %s", attempt, self.MAX_RETRIES, e)
            except (openai.error.InvalidRequestError, openai.error.AuthenticationError) as e:
                # e.g. context too long or bad API key: retrying only wastes time.
                logging.error("Non-retryable OpenAI error: %s", e)
                return None
            except Exception as e:
                logging.warning("Exception (attempt %d/%d): %s", attempt, self.MAX_RETRIES, e)
            else:
                self.chat_memory.append({'role': 'user', 'content': content})
                self.chat_memory.append({'role': 'assistant', 'content': response})
                return response
            # Back off before the next attempt, but not after the final failure.
            if attempt < self.MAX_RETRIES:
                time.sleep(self.RETRY_DELAY_SECONDS)
        return None

    def _create_response_chat(self, messages: List[Dict[str, Any]]) -> Optional[str]:
        """Make one ChatCompletion request and return the first choice's content.

        Returns ``None`` if the client returned ``None``.
        """
        response = openai.ChatCompletion.create(
            model=self.engine,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.MAX_TOKENS,
            top_p=self.topp,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
        )
        if response is None:
            return None
        return response['choices'][0]['message']['content']

    def get_template_based_responses(self,
                                     conversation_template: List[str],
                                     input_informations: List[Dict[str, Any]]) -> List[List[Optional[str]]]:
        """Play the multi-turn *conversation_template* once per input dict.

        Each template string is filled with ``str.format(**input_information)``
        and sent as one turn. Chat memory is cleared before starting and after
        each conversation, even if a turn raises.

        Returns:
            One list of replies (``None`` for failed turns) per input dict.
        """
        if self.chat_memory:
            logging.warning('Chat memory is not empty, we clear it before starting a new conversation.')
            self.clear_chat_memory()
        all_responses = []
        for input_information in input_informations:
            responses = []
            try:
                for turn_template in conversation_template:
                    content = turn_template.format(**input_information)
                    responses.append(self.create_response(content))
                    # Simple throttle between consecutive API calls.
                    time.sleep(1)
            finally:
                # Never leak one conversation's history into the next.
                self.clear_chat_memory()
            all_responses.append(responses)
        return all_responses
class OpenaiGeneralGpt:
    """Single-prompt text generator backed by ``openai.Completion.create``.

    Keeps no conversation state: each :meth:`create_response` call sends
    *content* as a standalone prompt.
    """

    # Maximum number of attempts create_response() makes before giving up.
    MAX_RETRIES = 10
    # Seconds to wait between failed attempts (never after the last one).
    RETRY_DELAY_SECONDS = 30
    # Value passed to the API as ``max_tokens``.
    MAX_TOKENS = 256

    def __init__(self, engine: str, temperatue=0.1, topp=0.95, frequency_penalty=0.0,
                 presence_penalty=0.0):
        # NOTE: ``temperatue`` is misspelled but is part of the public
        # signature; its value is stored as ``self.temperature``.
        self.engine = engine
        self.temperature = temperatue
        self.topp = topp
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        # Not read anywhere in this class; kept for attribute compatibility.
        self.chat_memory = []
        self.system_prompt = None

    def _create_response_completion(self, content: str) -> Optional[str]:
        """Make one Completion request for *content* and return the first choice's text.

        Generation stops at the first blank line (``"\\n\\n"``). Returns
        ``None`` if the client returned ``None``.
        """
        response = openai.Completion.create(
            engine=self.engine,
            prompt=content,
            temperature=self.temperature,
            max_tokens=self.MAX_TOKENS,
            top_p=self.topp,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            logprobs=1,
            stop=["\n\n"],
        )
        if response is None:
            return None
        return response['choices'][0]['text']

    def create_response(self, content: str) -> Optional[str]:
        """Return the completion text for *content*, retrying transient failures.

        Up to ``MAX_RETRIES`` attempts are made, sleeping
        ``RETRY_DELAY_SECONDS`` between failed attempts. Invalid-request and
        authentication errors are not retried.

        Returns:
            The completion text, or ``None`` if every attempt failed or a
            non-retryable error occurred.
        """
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                return self._create_response_completion(content)
            except openai.error.RateLimitError as e:
                logging.warning("Rate limit reached (attempt %d/%d): %s", attempt, self.MAX_RETRIES, e)
            except (openai.error.InvalidRequestError, openai.error.AuthenticationError) as e:
                # e.g. prompt too long or bad API key: retrying only wastes time.
                logging.error("Non-retryable OpenAI error: %s", e)
                return None
            except Exception as e:
                logging.warning("Exception (attempt %d/%d): %s", attempt, self.MAX_RETRIES, e)
            # Back off before the next attempt, but not after the final failure.
            if attempt < self.MAX_RETRIES:
                time.sleep(self.RETRY_DELAY_SECONDS)
        return None