Commit

Add RouteLLM Prompt Driver

collindutter committed Jul 16, 2024
1 parent 5c4b3a8 commit 58cd7c1
Showing 7 changed files with 847 additions and 58 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.
- `GriptapeCloudKnowledgeBaseVectorStoreDriver` to query Griptape Cloud Knowledge Bases.
- `RouteLlmPromptDriver` for using RouteLLM to route between a strong Prompt Driver and a weak Prompt Driver.

### Changed

34 changes: 34 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -447,3 +447,37 @@ agent = Agent(

agent.run("What is a good lasagna recipe?")
```

### RouteLLM

!!! info
    This driver requires the `drivers-prompt-routellm` [extra](../index.md#extras).
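
Assuming Griptape's usual extras naming, installation would look something like this (command inferred from the extra's name above, not verified against this commit):

```shell
pip install "griptape[drivers-prompt-routellm]"
```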

The [RouteLlmPromptDriver](../../reference/griptape/drivers/prompt/routellm_prompt_driver.md) uses [RouteLLM](https://github.com/lm-sys/RouteLLM) to route between a strong Prompt Driver and a weak Prompt Driver.

```python title="PYTEST_IGNORE"
from griptape.structures import Agent
from griptape.drivers import (
    RouteLlmPromptDriver,
    OpenAiChatPromptDriver,
    OllamaPromptDriver,
)


agent = Agent(
    prompt_driver=RouteLlmPromptDriver(
        max_attempts=2,
        strong_prompt_driver=OpenAiChatPromptDriver(
            model="gpt-4o",
        ),
        weak_prompt_driver=OllamaPromptDriver(
            model="llama3",
        ),
        threshold=0.11593,
    ),
)


agent.run("Hello!")  # Will be routed to the weak prompt driver
agent.run("What is the square root of 1,787,569?")  # Will be routed to the strong prompt driver
```
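
The `threshold` tunes the cost/quality trade-off: lower values send more prompts to the strong model. The example value above (0.11593) appears to come from RouteLLM's published calibration for the `mf` router, and RouteLLM's documentation describes how to calibrate a threshold for a target percentage of strong-model calls. Under the hood, the driver simply asks RouteLLM's `Controller` which model should handle the prompt. Here is a minimal sketch of that call, mirroring the parameters the driver uses (the model names are just the ones from the example above):

```python title="PYTEST_IGNORE"
from routellm.controller import Controller

# Mirror how RouteLlmPromptDriver builds its client: one router, plus the
# model names of the strong and weak Prompt Drivers.
controller = Controller(
    routers=["mf"],
    strong_model="gpt-4o",
    weak_model="llama3",
)

# route() returns the name of the selected model; the driver compares this
# against strong_prompt_driver.model and weak_prompt_driver.model to pick
# which Prompt Driver actually runs the prompt.
routed_model = controller.route(
    prompt="What is the square root of 1,787,569?",
    router="mf",
    threshold=0.11593,
)
print(routed_model)  # e.g. "gpt-4o" if the router deems the prompt hard enough
```
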
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -10,6 +10,7 @@
from .prompt.google_prompt_driver import GooglePromptDriver
from .prompt.dummy_prompt_driver import DummyPromptDriver
from .prompt.ollama_prompt_driver import OllamaPromptDriver
from .prompt.routellm_prompt_driver import RouteLlmPromptDriver

from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver
from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver
@@ -122,6 +123,7 @@
"GooglePromptDriver",
"DummyPromptDriver",
"OllamaPromptDriver",
"RouteLlmPromptDriver",
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
103 changes: 103 additions & 0 deletions griptape/drivers/prompt/routellm_prompt_driver.py
@@ -0,0 +1,103 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.drivers import BasePromptDriver
from griptape.tokenizers.dummy_tokenizer import DummyTokenizer
from griptape.utils.import_utils import import_optional_dependency

if TYPE_CHECKING:
    from collections.abc import Iterator

    from routellm.controller import Controller

    from griptape.common import (
        DeltaMessage,
        Message,
        PromptStack,
    )
    from griptape.tokenizers import BaseTokenizer


@define
class RouteLlmPromptDriver(BasePromptDriver):
"""[RouteLlm](https://github.com/lm-sys/RouteLLM) Prompt Driver.
Attributes:
strong_prompt_driver: Prompt Driver to use when routing to the strong model.
weak_prompt_driver: Prompt Driver to use when routing to the weak model.
threshold: Cost threshold when routing between models.
router: Router to use, defaults to "mf".
client: RouteLlm Controller.
tokenizer: Tokenizer to use, defaults to DummyTokenizer. After running, it will be set to the tokenizer of the routed prompt driver.
"""

    # Set after routing, to the model of the selected Prompt Driver.
    model: str = field(init=False, default=None)
    strong_prompt_driver: BasePromptDriver = field(
        kw_only=True,
        metadata={"serializable": False},
    )
    weak_prompt_driver: BasePromptDriver = field(
        kw_only=True,
        metadata={"serializable": False},
    )
    router: str = field(kw_only=True, default="mf", metadata={"serializable": True})
    threshold: float = field(kw_only=True, metadata={"serializable": True})
    # Lazily import RouteLLM and build the Controller from this instance's
    # router and the two Prompt Drivers' models (takes_self=True gives the
    # Factory access to the instance).
    client: Controller = field(
        kw_only=True,
        default=Factory(
            lambda self: import_optional_dependency("routellm.controller").Controller(
                routers=[self.router],
                strong_model=self.strong_prompt_driver.model,
                weak_model=self.weak_prompt_driver.model,
            ),
            takes_self=True,
        ),
        metadata={"serializable": False},
    )
    # Placeholder until a run routes to a concrete driver, whose tokenizer is then adopted.
    tokenizer: BaseTokenizer = field(
        init=False,
        default=Factory(
            lambda: DummyTokenizer(),
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> Message:
        prompt_driver = self._get_prompt_driver(prompt_stack)

        return prompt_driver.try_run(prompt_stack)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        prompt_driver = self._get_prompt_driver(prompt_stack)

        return prompt_driver.try_stream(prompt_stack)

    def _get_prompt_driver(self, prompt_stack: PromptStack) -> BasePromptDriver:
        # Route based on the most recent message in the Prompt Stack.
        if prompt_stack.messages:
            prompt = prompt_stack.messages[-1].value
        else:
            raise ValueError("Prompt stack is empty.")

        if isinstance(prompt, str):
            routed_model = self.client.route(
                prompt=prompt,
                router=self.router,
                threshold=self.threshold,
            )
        else:
            raise ValueError("Prompt must be a string.")

        # Match the routed model name back to the corresponding Prompt Driver.
        if routed_model == self.strong_prompt_driver.model:
            prompt_driver = self.strong_prompt_driver
        elif routed_model == self.weak_prompt_driver.model:
            prompt_driver = self.weak_prompt_driver
        else:
            raise ValueError(f"Model '{routed_model}' not found.")

        # Adopt the routed driver's model and tokenizer so downstream behavior matches.
        self.model = prompt_driver.model
        self.tokenizer = prompt_driver.tokenizer

        return prompt_driver