Skip to content

Commit

Permalink
Add parallel tool calls toggle (#1350)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Nov 16, 2024
1 parent e56b4f5 commit bd75c02
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `TrafilaturaWebScraperDriver.no_ssl` parameter to disable SSL verification. Defaults to `False`.
- `CsvExtractionEngine.format_header` parameter to format the header row.
- `PromptStack.from_artifact` factory method for creating a Prompt Stack with a user message from an Artifact.
- `OpenAiChatPromptDriver.parallel_tool_calls` parameter for toggling parallel tool calling. Defaults to `True`.

### Changed

Expand Down
8 changes: 7 additions & 1 deletion griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
seed: An optional OpenAi Chat Completion seed.
ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
parallel_tool_calls: A flag to enable parallel tool calls. Defaults to `True`.
"""

base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
Expand All @@ -75,6 +76,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
ignored_exception_types: tuple[type[Exception], ...] = field(
default=Factory(
lambda: (
Expand Down Expand Up @@ -147,7 +149,11 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
"user": self.user,
"seed": self.seed,
**(
{"tools": self.__to_openai_tools(prompt_stack.tools), "tool_choice": self.tool_choice}
{
"tools": self.__to_openai_tools(prompt_stack.tools),
"tool_choice": self.tool_choice,
"parallel_tool_calls": self.parallel_tool_calls,
}
if prompt_stack.tools and self.use_native_tools
else {}
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_to_dict(self, config):
"azure_endpoint": "http://localhost:8080",
"api_version": "2023-05-15",
"organization": None,
"parallel_tool_calls": True,
"response_format": None,
"seed": None,
"temperature": 0.1,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/configs/drivers/test_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_to_dict(self, config):
"base_url": None,
"model": "gpt-4o",
"organization": None,
"parallel_tool_calls": True,
"response_format": None,
"seed": None,
"temperature": 0.1,
Expand Down
16 changes: 14 additions & 2 deletions tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,13 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
temperature=driver.temperature,
user=driver.user,
messages=messages,
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
**{
"tools": self.OPENAI_TOOLS,
"tool_choice": driver.tool_choice,
"parallel_tool_calls": driver.parallel_tool_calls,
}
if use_native_tools
else {},
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
Expand Down Expand Up @@ -120,7 +126,13 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack,
user=driver.user,
stream=True,
messages=messages,
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
**{
"tools": self.OPENAI_TOOLS,
"tool_choice": driver.tool_choice,
"parallel_tool_calls": driver.parallel_tool_calls,
}
if use_native_tools
else {},
foo="bar",
)

Expand Down
16 changes: 14 additions & 2 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,13 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
user=driver.user,
messages=messages,
seed=driver.seed,
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
**{
"tools": self.OPENAI_TOOLS,
"tool_choice": driver.tool_choice,
"parallel_tool_calls": driver.parallel_tool_calls,
}
if use_native_tools
else {},
foo="bar",
)
assert isinstance(message.value[0], TextArtifact)
Expand Down Expand Up @@ -461,7 +467,13 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack,
messages=messages,
seed=driver.seed,
stream_options={"include_usage": True},
**{"tools": self.OPENAI_TOOLS, "tool_choice": driver.tool_choice} if use_native_tools else {},
**{
"tools": self.OPENAI_TOOLS,
"tool_choice": driver.tool_choice,
"parallel_tool_calls": driver.parallel_tool_calls,
}
if use_native_tools
else {},
foo="bar",
)

Expand Down

0 comments on commit bd75c02

Please sign in to comment.