From ec94e7a6b4bbb8d63cc44d3ac1961abd7e7c1a73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 14 Dec 2023 10:43:55 +0100 Subject: [PATCH] Create a new `AsyncOpenAI` client for each request The client currently returns a `TimeOutError` after many requests. This seems to be a problem on the OpenAI side, but we provide this temporary fix so the OpenAI integration works fine in Outlines. --- outlines/models/openai.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/outlines/models/openai.py b/outlines/models/openai.py index cbbe00b00..8809177e9 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -126,9 +126,15 @@ def __init__( else: self.config = OpenAIConfig(model=model_name) - self.client = openai.AsyncOpenAI( - api_key=api_key, max_retries=max_retries, timeout=timeout + # This is necessary because of an issue with the OpenAI API. + # Status updates: https://github.com/openai/openai-python/issues/769 + self.create_client = functools.partial( + openai.AsyncOpenAI, + api_key=api_key, + max_retries=max_retries, + timeout=timeout, ) + self.system_prompt = system_prompt # We count the total number of prompt and generated tokens as returned @@ -173,8 +179,9 @@ def __call__( ) ) if "gpt-" in self.config.model: + client = self.create_client() response, prompt_tokens, completion_tokens = generate_chat( - prompt, self.system_prompt, self.client, config + prompt, self.system_prompt, client, config ) self.prompt_tokens += prompt_tokens self.completion_tokens += completion_tokens @@ -232,8 +239,9 @@ def generate_choice( config = replace(config, logit_bias=mask, max_tokens=max_tokens_left) + client = self.create_client() response, prompt_tokens, completion_tokens = generate_chat( - prompt, self.system_prompt, self.client, config + prompt, self.system_prompt, client, config ) self.prompt_tokens += prompt_tokens self.completion_tokens += completion_tokens @@ -315,6 +323,8 @@ async def call_api(prompt, 
system_prompt, config): messages=system_message + user_message, **asdict(config), # type: ignore ) + await client.close() + return responses.model_dump() system_message = (