Commit 41343ae

add system prompt to openai call and timeout
add usage tracking
fix comment in helper function

HerrIvan committed Nov 24, 2023
1 parent a916372 commit 41343ae
Showing 1 changed file with 63 additions and 11 deletions.
outlines/models/openai.py (74 changes: 63 additions & 11 deletions)
@@ -17,6 +17,20 @@
from openai import AsyncOpenAI


+@dataclass(frozen=True)
+class OpenAIUsage:
+    prompt_tokens: Optional[int] = 0
+    completion_tokens: Optional[int] = 0
+    total_tokens: Optional[int] = 0
+
+    def __add__(self, other):
+        return OpenAIUsage(
+            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+            completion_tokens=self.completion_tokens + other.completion_tokens,
+            total_tokens=self.total_tokens + other.total_tokens,
+        )
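A quick illustration of the accumulation semantics defined above (example values, not part of the commit): because `__add__` returns a new frozen instance, usage records can be summed without mutating either operand.

usage_a = OpenAIUsage(prompt_tokens=12, completion_tokens=30, total_tokens=42)
usage_b = OpenAIUsage(prompt_tokens=8, completion_tokens=10, total_tokens=18)
print(usage_a + usage_b)
# OpenAIUsage(prompt_tokens=20, completion_tokens=40, total_tokens=60)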


@dataclass(frozen=True)
class OpenAIConfig:
"""Represents the parameters of the OpenAI API.
@@ -79,6 +93,8 @@ def __init__(
model_name: str,
api_key: Optional[str] = None,
max_retries: int = 6,
+        timeout: Optional[float] = None,
+        role: Optional[str] = None,
config: Optional[OpenAIConfig] = None,
):
"""Create an `OpenAI` instance.
@@ -93,6 +109,8 @@ def __init__(
`openai.api_key`.
max_retries
The maximum number of retries when calls to the API fail.
+        role
+            The "system prompt": sent as the system message with every request.
config
An instance of `OpenAIConfig`. Can be useful to specify some
parameters that cannot be set by calling this class' methods.
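A sketch of how the two new constructor arguments could be used; the model name and values below are placeholders, not taken from this commit:

model = OpenAI(
    "gpt-3.5-turbo",  # placeholder model name
    timeout=30.0,  # seconds, forwarded to openai.AsyncOpenAI
    role="You are a terse, factual assistant.",  # becomes the system message
)

Because `timeout` is passed straight through to `openai.AsyncOpenAI`, it acts as the default timeout for every request made by this client.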
@@ -120,7 +138,13 @@ def __init__(
else:
self.config = OpenAIConfig(model=model_name)

-        self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=max_retries)
+        self.client = openai.AsyncOpenAI(
+            api_key=api_key, max_retries=max_retries, timeout=timeout
+        )
+        self.system_prompt = role
+
+        self.total_usage = OpenAIUsage()
+        self.last_usage: Union[OpenAIUsage, None] = None

def __call__(
self,
@@ -158,7 +182,7 @@ def __call__(
)
)
if "gpt-" in self.config.model:
-            return generate_chat(prompt, self.client, config)
+            return self.generate_chat(prompt, config)

def generate_choice(
self, prompt: str, choices: List[str], max_tokens: Optional[int] = None
@@ -210,7 +234,7 @@ def generate_choice(
break

config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
-            response = generate_chat(prompt, self.client, config)
+            response = self.generate_chat(prompt, config)
encoded_response = tokenizer.encode(response)

if encoded_response in encoded_choices_left:
@@ -243,6 +267,26 @@ def generate_choice(

return choice
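For context, the masking loop in generate_choice above steers generation toward the remaining choices through `logit_bias`, which the OpenAI API reads as per-token-id sampling weights in the range -100 to 100. A generic illustration, with invented token ids:

# A bias of 100 effectively restricts sampling to the listed token ids.
mask = {15339: 100, 31373: 100}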

+    def generate_chat(
+        self, prompt: Union[str, List[str]], config: OpenAIConfig
+    ) -> np.ndarray:
+        """Call the async function to generate a chat response and keep track
+        of usage data.
+
+        Parameters
+        ----------
+        prompt
+            A string (or list of strings) used to prompt the model as the user
+            message.
+        config
+            An instance of `OpenAIConfig`.
+
+        """
+        results, usage = generate_chat(prompt, self.system_prompt, self.client, config)
+
+        self.last_usage = OpenAIUsage(**usage)
+        self.total_usage += self.last_usage
+
+        return results
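With this bookkeeping, every chat call refreshes `last_usage` and grows `total_usage`. A hypothetical session (the prompt text is invented):

answer = model("What is the capital of France?")
print(model.last_usage)   # usage of the most recent call only
print(model.total_usage)  # cumulative usage across all calls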

def generate_json(self):
"""Call the OpenAI API to generate a JSON object."""
raise NotImplementedError
@@ -255,12 +299,19 @@ def __repr__(self):


@cache(ignore="client")
-@functools.partial(outlines.vectorize, signature="(),(),()->(s)")
+@functools.partial(outlines.vectorize, signature="(),(),(),()->(s),()")
async def generate_chat(
-    prompt: str, client: "AsyncOpenAI", config: OpenAIConfig
-) -> np.ndarray:
+    prompt: str,
+    system_prompt: Union[str, None],
+    client: "AsyncOpenAI",
+    config: OpenAIConfig,
+) -> Tuple[np.ndarray, Dict]:
responses = await client.chat.completions.create(
messages=[{"role": "user", "content": prompt}], **asdict(config) # type: ignore
messages=(
[{"role": "system", "content": system_prompt}] if system_prompt else []
)
+ [{"role": "user", "content": prompt}],
**asdict(config), # type: ignore
)

if config.n == 1:
@@ -270,7 +321,7 @@ async def generate_chat(
[responses.choices[i].message.content for i in range(config.n)]
)

-    return results
+    return results, responses.usage.model_dump()
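The conditional concatenation that builds `messages` above is equivalent to this more explicit sketch (for clarity only, not the committed code):

messages = []
if system_prompt:  # an empty or None system prompt adds no system message
    messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})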


openai = OpenAI
@@ -292,8 +343,8 @@ def find_response_choices_intersection(
choices.
Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices
-    `[[1, 2], [1, 2, 3], [6, 7, 8]]` then the function will return `[1, 2]` as the
-    intersection, and `[1, 2, 3]` as the choice that is left.
+    `[[1, 2], [1, 2, 3], [6, 7, 8]]` then the function will return `[1, 2, 3]` as the
+    intersection, and `[[]]` as the list of choices left.
Parameters
----------
@@ -305,7 +356,8 @@
Returns
-------
A tuple that contains the longest intersection between the response and the
-    different choices, and the choices which start with this intersection.
+    different choices, and the choices which start with this intersection, with the
+    intersection removed.
"""
max_len_prefix = 0
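The rest of the function body is elided in this view. A minimal sketch that matches the corrected docstring (an illustration, not the repository's implementation):

def find_response_choices_intersection(response, choices):
    # Track the longest common prefix found so far and the matching
    # choices with that prefix stripped off.
    max_len_prefix = 0
    intersection = []
    choices_left = []
    for choice in choices:
        # Length of the common prefix between the response and this choice.
        i = 0
        while i < len(response) and i < len(choice) and response[i] == choice[i]:
            i += 1
        if i > max_len_prefix:
            max_len_prefix = i
            intersection = choice[:i]
            choices_left = [choice[i:]]
        elif i == max_len_prefix and i > 0:
            choices_left.append(choice[i:])
    return intersection, choices_left

On the docstring's example this returns `([1, 2, 3], [[]])`: the best-matching choice is consumed entirely, leaving one empty remainder.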
