stream chat for GPTAPI (#197)
* _stream_chat for GPTAPI

* mapping to role that openai supports

* update
liujiangning30 authored Jul 12, 2024
1 parent c234e6b commit 139a810
Showing 1 changed file with 115 additions and 24 deletions.
139 changes: 115 additions & 24 deletions lagent/llms/openai.py
@@ -56,7 +56,8 @@ def __init__(self,
                  meta_template: Optional[Dict] = [
                      dict(role='system', api_role='system'),
                      dict(role='user', api_role='user'),
-                     dict(role='assistant', api_role='assistant')
+                     dict(role='assistant', api_role='assistant'),
+                     dict(role='environment', api_role='system')
                  ],
                  openai_api_base: str = OPENAI_API_BASE,
                  proxies: Optional[Dict] = None,
@@ -139,14 +140,17 @@ def stream_chat(
         assert isinstance(inputs, list)
         if 'max_tokens' in gen_params:
             raise NotImplementedError('unsupported parameter: max_tokens')
-        # gen_params = {**self.gen_params, **gen_params}
         gen_params = self.update_gen_params(**gen_params)
+        gen_params['stream'] = True

         resp = ''
         finished = False
         stop_words = gen_params.get('stop_words')
-        for text in self._chat(inputs, **gen_params):
+        if stop_words is None:
+            stop_words = []
+        # mapping to role that openai supports
+        messages = self.template_parser._prompt2api(inputs)
+        for text in self._stream_chat(messages, **gen_params):
             resp += text
             if not resp:
                 continue
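
With this change `stream_chat` maps the prompt to OpenAI roles, forces `stream=True`, and accumulates the deltas yielded by the new `_stream_chat`. A hedged usage sketch follows; it assumes the generator yields `(status, partial_text, _)` tuples per lagent's ModelStatusCode convention elsewhere in the library, which should be verified against this revision before relying on it.

# Usage sketch under the stated assumptions; not part of the diff.
from lagent.llms import GPTAPI

llm = GPTAPI(model_type='gpt-3.5-turbo', key='sk-...')  # placeholder key
prompt = [dict(role='user', content='Say hello in one word.')]

for status, partial, _ in llm.stream_chat(prompt):
    # `partial` is the response accumulated so far (`resp` above).
    print(partial)
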
Expand All @@ -171,22 +175,6 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
         Returns:
             str: The generated string.
         """
-
-        def _stream_chat(raw_response):
-            for chunk in raw_response.iter_lines(
-                    chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
-                if chunk:
-                    decoded = chunk.decode('utf-8')
-                    if decoded == 'data: [DONE]':
-                        return
-                    if decoded[:6] == 'data: ':
-                        decoded = decoded[6:]
-                    response = json.loads(decoded)
-                    choice = response['choices'][0]
-                    if choice['finish_reason'] == 'stop':
-                        return
-                    yield choice['delta']['content']
-
         assert isinstance(messages, list)
         gen_params = gen_params.copy()

@@ -245,12 +233,115 @@ def _stream_chat(raw_response):
                     headers=header,
                     data=json.dumps(data),
                     proxies=self.proxies)
-                if data.get('stream', False):
-                    return _stream_chat(raw_response)
-                else:
-                    response = raw_response.json()
-                    return response['choices'][0]['message']['content'].strip()
+                response = raw_response.json()
+                return response['choices'][0]['message']['content'].strip()
+            except requests.ConnectionError:
+                print('Got connection error, retrying...')
+                continue
+            except requests.JSONDecodeError:
+                print('JsonDecode error, got', str(raw_response.content))
+                continue
+            except KeyError:
+                if 'error' in response:
+                    if response['error']['code'] == 'rate_limit_exceeded':
+                        time.sleep(1)
+                        continue
+                    elif response['error']['code'] == 'insufficient_quota':
+                        self.invalid_keys.add(key)
+                        self.logger.warn(f'insufficient_quota key: {key}')
+                        continue
+
+                    print('Find error message in response: ',
+                          str(response['error']))
+            except Exception as error:
+                print(str(error))
+            max_num_retries += 1
+
+        raise RuntimeError('Calling OpenAI failed after retrying for '
+                           f'{max_num_retries} times. Check the logs for '
+                           'details.')
+
+    def _stream_chat(self, messages: List[dict], **gen_params) -> str:
+        """Generate completion from a list of templates.
+
+        Args:
+            messages (List[dict]): a list of prompt dictionaries
+            gen_params: additional generation configuration
+
+        Returns:
+            str: The generated string.
+        """
+
+        def streaming(raw_response):
+            for chunk in raw_response.iter_lines(
+                    chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
+                if chunk:
+                    decoded = chunk.decode('utf-8')
+                    if decoded == 'data: [DONE]':
+                        return
+                    if decoded[:6] == 'data: ':
+                        decoded = decoded[6:]
+                    response = json.loads(decoded)
+                    choice = response['choices'][0]
+                    if choice['finish_reason'] == 'stop':
+                        return
+                    yield choice['delta']['content']
+
+        assert isinstance(messages, list)
+        gen_params = gen_params.copy()
+
+        # Hold out 100 tokens due to potential errors in tiktoken calculation
+        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        if max_tokens <= 0:
+            return ''
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            if len(self.invalid_keys) == len(self.keys):
+                raise RuntimeError('All keys have insufficient quota.')
+
+            # find the next valid key
+            while True:
+                self.key_ctr += 1
+                if self.key_ctr == len(self.keys):
+                    self.key_ctr = 0
+
+                if self.keys[self.key_ctr] not in self.invalid_keys:
+                    break
+
+            key = self.keys[self.key_ctr]
+
+            header = {
+                'Authorization': f'Bearer {key}',
+                'content-type': 'application/json',
+            }
+
+            if self.orgs:
+                self.org_ctr += 1
+                if self.org_ctr == len(self.orgs):
+                    self.org_ctr = 0
+                header['OpenAI-Organization'] = self.orgs[self.org_ctr]
+
+            response = dict()
+            try:
+                gen_params_new = gen_params.copy()
+                data = dict(
+                    model=self.model_type,
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    n=1,
+                    stop=gen_params_new.pop('stop_words'),
+                    frequency_penalty=gen_params_new.pop('repetition_penalty'),
+                    **gen_params_new,
+                )
+                if self.json_mode:
+                    data['response_format'] = {'type': 'json_object'}
+                raw_response = requests.post(
+                    self.url,
+                    headers=header,
+                    data=json.dumps(data),
+                    proxies=self.proxies)
+                return streaming(raw_response)
             except requests.ConnectionError:
                 print('Got connection error, retrying...')
                 continue
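
The nested `streaming` generator parses OpenAI's server-sent events: each line is `data: {json}` carrying a `choices[0].delta.content` fragment, and the stream ends at `data: [DONE]` or a `stop` finish reason. Below is a self-contained sketch of that parsing, runnable against canned bytes instead of a live HTTP response; `parse_sse_lines` is an illustrative name, not part of the commit.

# Standalone illustration of the SSE parsing above; not part of the diff.
import json

def parse_sse_lines(lines):
    # Yield content deltas from OpenAI-style 'data: ...' stream lines.
    for raw in lines:
        decoded = raw.decode('utf-8')
        if decoded == 'data: [DONE]':
            return
        if decoded.startswith('data: '):
            decoded = decoded[len('data: '):]
        choice = json.loads(decoded)['choices'][0]
        if choice['finish_reason'] == 'stop':
            return
        yield choice['delta']['content']

canned = [
    b'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}',
    b'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": null}]}',
    b'data: [DONE]',
]
assert ''.join(parse_sse_lines(canned)) == 'Hello'
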
