Skip to content

Commit

Permalink
Fix generation parameters in API models (#181)
Browse files Browse the repository at this point in the history
* fix gen_params

* remove  from

---------

Co-authored-by: wangzy <[email protected]>
  • Loading branch information
braisedpork1964 and wangzy authored Apr 12, 2024
1 parent e43646f commit b8cf464
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lagent/llms/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __init__(self,
*,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
top_k: int = 40,
temperature: float = 0.8,
repetition_penalty: float = 0.0,
stop_words: Union[List[str], str] = None):
Expand Down
2 changes: 1 addition & 1 deletion lagent/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self,
*,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
top_k: float = 40,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
stop_words: Union[List[str], str] = None):
Expand Down
14 changes: 11 additions & 3 deletions lagent/llms/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
Expand All @@ -10,6 +11,8 @@

from .base_api import BaseAPIModel

warnings.simplefilter('default')

OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'


Expand Down Expand Up @@ -54,6 +57,10 @@ def __init__(self,
],
openai_api_base: str = OPENAI_API_BASE,
**gen_params):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
DeprecationWarning)
gen_params.pop('top_k')
super().__init__(
model_type=model_type,
meta_template=meta_template,
Expand Down Expand Up @@ -170,14 +177,15 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
header['OpenAI-Organization'] = self.orgs[self.org_ctr]

try:
gen_params_new = gen_params.copy()
data = dict(
model=self.model_type,
messages=messages,
max_tokens=max_tokens,
n=1,
stop=gen_params.pop('stop_words'),
frequency_penalty=gen_params.pop('repetition_penalty'),
**gen_params,
stop=gen_params_new.pop('stop_words'),
frequency_penalty=gen_params_new.pop('repetition_penalty'),
**gen_params_new,
)
raw_response = requests.post(
self.url, headers=header, data=json.dumps(data))
Expand Down

0 comments on commit b8cf464

Please sign in to comment.