Fix hugging face prompt pipeline driver + docs
dylanholmes committed May 21, 2024
1 parent 9db5c8d commit 8dba0d4
Showing 4 changed files with 326 additions and 8 deletions.
9 changes: 4 additions & 5 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -319,9 +319,9 @@ agent.run("Write the code for a snake game.")
### Hugging Face Pipeline

!!! info
-    This driver requires the `drivers-prompt-huggingface` [extra](../index.md#extras).
+    This driver requires the `drivers-prompt-huggingface-pipeline` [extra](../index.md#extras).

-The [HuggingFaceHubPromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks:
+The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks:

- text2text-generation
- text-generation
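
The hunks below update the docs example for the renamed driver. As a reviewing aid, here is a compact, self-contained sketch of what the updated example looks like: the driver class, model name, and `prompt_stack_to_string` wiring match the diff, while the converter body, the ruleset, and the final `run` call are collapsed in the hunks and are therefore illustrative assumptions only.

```python
from griptape.structures import Agent
from griptape.drivers import HuggingFacePipelinePromptDriver
from griptape.rules import Rule, Ruleset
from griptape.utils import PromptStack
from griptape.config import StructureConfig


def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str:
    # Assumed body: the original implementation is collapsed in the diff.
    # Flattens every prompt stack input into a single plain-text prompt.
    return "\n".join(i.content for i in prompt_stack.inputs) + "\nAssistant:"


agent = Agent(
    config=StructureConfig(
        prompt_driver=HuggingFacePipelinePromptDriver(
            model="tiiuae/falcon-7b-instruct",
            prompt_stack_to_string=prompt_stack_to_string_converter,
        )
    ),
    # Illustrative ruleset; the original example's rules are not visible in this diff.
    rulesets=[Ruleset(name="Assistant", rules=[Rule("Keep answers short.")])],
)

agent.run("Write the code for a snake game.")
```

Note that, unlike the Hugging Face Hub driver, no `api_token` is passed: the pipeline driver runs inference locally.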
@@ -332,7 +332,7 @@ The [HuggingFaceHubPromptDriver](../../reference/griptape/drivers/prompt/hugging
```python
import os
from griptape.structures import Agent
-from griptape.drivers import HuggingFaceHubPromptDriver
+from griptape.drivers import HuggingFacePipelinePromptDriver
from griptape.rules import Rule, Ruleset
from griptape.utils import PromptStack
from griptape.config import StructureConfig
@@ -357,9 +357,8 @@ def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str:

agent = Agent(
config=StructureConfig(
-        prompt_driver=HuggingFaceHubPromptDriver(
+        prompt_driver=HuggingFacePipelinePromptDriver(
model="tiiuae/falcon-7b-instruct",
-            api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],
prompt_stack_to_string=prompt_stack_to_string_converter,
)
),
griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@@ -21,6 +21,7 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}

+    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
tokenizer: HuggingFaceTokenizer = field(
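
The added `max_tokens` field is keyword-only, defaults to 250, and is marked serializable. A brief usage sketch follows; how the driver consumes the value internally is not shown in this diff, so the comment about capping generation is an assumption.

```python
from griptape.drivers import HuggingFacePipelinePromptDriver

# Omitting max_tokens falls back to the field's default of 250.
driver = HuggingFacePipelinePromptDriver(
    model="tiiuae/falcon-7b-instruct",
    max_tokens=512,  # assumed to cap the number of tokens generated by the pipeline
)
```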