diff --git a/outlines/models/openai.py b/outlines/models/openai.py index 7d7ac61c8..771b6db71 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -17,6 +17,20 @@ from openai import AsyncOpenAI +@dataclass(frozen=True) +class OpenAIUsage: + prompt_tokens: Optional[int] = 0 + completion_tokens: Optional[int] = 0 + total_tokens: Optional[int] = 0 + + def __add__(self, other): + return OpenAIUsage( + prompt_tokens=self.prompt_tokens + other.prompt_tokens, + completion_tokens=self.completion_tokens + other.completion_tokens, + total_tokens=self.total_tokens + other.total_tokens, + ) + + @dataclass(frozen=True) class OpenAIConfig: """Represents the parameters of the OpenAI API. @@ -79,6 +93,8 @@ def __init__( model_name: str, api_key: Optional[str] = None, max_retries: int = 6, + timeout: Optional[float] = None, + role: Optional[str] = None, config: Optional[OpenAIConfig] = None, ): """Create an `OpenAI` instance. @@ -93,6 +109,8 @@ def __init__( `openai.api_key`. max_retries The maximum number of retries when calls to the API fail. + role + The system prompt, sent as a "system" role message ahead of the user prompt; no system message is sent when it is empty or `None`. config An instance of `OpenAIConfig`. Can be useful to specify some parameters that cannot be set by calling this class' methods. 
@@ -120,7 +138,13 @@ def __init__( else: self.config = OpenAIConfig(model=model_name) - self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=max_retries) + self.client = openai.AsyncOpenAI( + api_key=api_key, max_retries=max_retries, timeout=timeout + ) + self.system_prompt = role + + self.total_usage = OpenAIUsage() + self.last_usage: Union[OpenAIUsage, None] = None def __call__( self, @@ -158,7 +182,7 @@ def __call__( ) ) if "gpt-" in self.config.model: - return generate_chat(prompt, self.client, config) + return self.generate_chat(prompt, config) def generate_choice( self, prompt: str, choices: List[str], max_tokens: Optional[int] = None @@ -210,7 +234,7 @@ def generate_choice( break config = replace(config, logit_bias=mask, max_tokens=max_tokens_left) - response = generate_chat(prompt, self.client, config) + response = self.generate_chat(prompt, config) encoded_response = tokenizer.encode(response) if encoded_response in encoded_choices_left: @@ -243,6 +267,26 @@ def generate_choice( return choice + def generate_chat( + self, prompt: Union[str, List[str]], config: OpenAIConfig + ) -> np.ndarray: + """Call the async function that generates a chat response and keep track of usage data. + + Parameters + ---------- + prompt + A string, or a list of strings, sent to the model as the user message + config + An instance of `OpenAIConfig`. 
+ + """ + results, usage = generate_chat(prompt, self.system_prompt, self.client, config) + + self.last_usage = OpenAIUsage(**usage) + self.total_usage += self.last_usage + + return results + def generate_json(self): """Call the OpenAI API to generate a JSON object.""" raise NotImplementedError @@ -255,12 +299,19 @@ def __repr__(self): @cache(ignore="client") -@functools.partial(outlines.vectorize, signature="(),(),()->(s)") +@functools.partial(outlines.vectorize, signature="(),(),(),()->(s),()") async def generate_chat( - prompt: str, client: "AsyncOpenAI", config: OpenAIConfig -) -> np.ndarray: + prompt: str, + system_prompt: Union[str, None], + client: "AsyncOpenAI", + config: OpenAIConfig, +) -> Tuple[np.ndarray, Dict]: responses = await client.chat.completions.create( - messages=[{"role": "user", "content": prompt}], **asdict(config) # type: ignore + messages=( + [{"role": "system", "content": system_prompt}] if system_prompt else [] + ) + + [{"role": "user", "content": prompt}], + **asdict(config), # type: ignore ) if config.n == 1: @@ -270,7 +321,7 @@ async def generate_chat( [responses.choices[i].message.content for i in range(config.n)] ) - return results + return results, responses.usage.model_dump() openai = OpenAI @@ -292,8 +343,8 @@ def find_response_choices_intersection( choices. Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices - `[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2]` as the - intersection, and `[1, 2, 3]` as the choice that is left. + `[[1, 2], [1, 2, 3], [6, 7, 8]]` then the function will return `[1, 2, 3]` as the + intersection, and `[[]]` as the list of choices left. Parameters ---------- @@ -305,7 +356,8 @@ def find_response_choices_intersection( Returns ------- A tuple that contains the longest intersection between the response and the - different choices, and the choices which start with this intersection. 
+ different choices, and the choices which start with this intersection, with the + intersection removed. """ max_len_prefix = 0