From c9b92ea3109cbe936091d3640921dde0a82a92a9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 9 Jul 2024 09:58:53 -0700 Subject: [PATCH] Fix azure streaming (#946) --- griptape/drivers/prompt/azure_openai_chat_prompt_driver.py | 6 +++++- griptape/drivers/prompt/openai_chat_prompt_driver.py | 5 ++--- .../drivers/prompt/test_azure_openai_chat_prompt_driver.py | 7 +------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index 50e9effe6..64145b1a9 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -44,6 +44,10 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver): def _base_params(self, prompt_stack: PromptStack) -> dict: params = super()._base_params(prompt_stack) # TODO: Add `seed` parameter once Azure supports it. - del params["seed"] + if "seed" in params: + del params["seed"] + # TODO: Add `stream_options` parameter once Azure supports it. + if "stream_options" in params: + del params["stream_options"] return params diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index e1e046d11..a89b4eb57 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -90,9 +90,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: raise Exception("Completion with more than one choice is not supported yet.") def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: - result = self.client.chat.completions.create( - **self._base_params(prompt_stack), stream=True, stream_options={"include_usage": True} - ) + result = self.client.chat.completions.create(**self._base_params(prompt_stack), stream=True) for chunk in result: if chunk.usage is not None: @@ -124,6 +122,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "seed": self.seed, **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), + **({"stream_options": {"include_usage": True}} if self.stream else {}), } if self.response_format == "json_object": diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 92544a74e..378ecc3da 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -57,12 +57,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, # Then mock_chat_completion_stream_create.assert_called_once_with( - model=driver.model, - temperature=driver.temperature, - user=driver.user, - stream=True, - messages=messages, - stream_options={"include_usage": True}, + model=driver.model, temperature=driver.temperature, user=driver.user, stream=True, messages=messages ) assert event.content.text == "model-output"