Add support for Llama 3 on Amazon SageMaker
dylanholmes committed May 23, 2024
1 parent f7bf180 commit ef9e949
Showing 10 changed files with 182 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,10 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration.
- `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models.
- `SageMakerJumpStartLlama3InstructPromptModelDriver` for using the Llama 3 Instruct model on SageMaker JumpStart.

### Changed
- Default the value of `azure_deployment` on all Azure Drivers to the model the Driver is using.
- Field `azure_ad_token` on all Azure Drivers is no longer serializable.
- **BREAKING**: The `AmazonSageMakerPromptDriver.model` parameter, which was passed to `SageMakerRuntime.Client.invoke_endpoint` as `EndpointName`, has been renamed to `AmazonSageMakerPromptDriver.endpoint`.
- **BREAKING**: The `AmazonSageMakerPromptDriver.model` parameter is now optional and is passed to `SageMakerRuntime.Client.invoke_endpoint` as `InferenceComponentName` (instead of `EndpointName`); see the migration sketch below.
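
A minimal migration sketch for the two breaking changes above (the endpoint and inference component names are placeholders):

```python
from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver

# Before: the endpoint name was passed as `model` (sent to `invoke_endpoint` as `EndpointName`).
# driver = AmazonSageMakerPromptDriver(
#     model="my-endpoint-name",
#     prompt_model_driver=SageMakerFalconPromptModelDriver(),
# )

# After: the endpoint name is passed as `endpoint`; `model` is optional and, when set,
# is sent to `invoke_endpoint` as `InferenceComponentName`.
driver = AmazonSageMakerPromptDriver(
    endpoint="my-endpoint-name",
    model="my-inference-component-name",  # omit for single-model endpoints
    prompt_model_driver=SageMakerFalconPromptModelDriver(),
)
```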

## [0.25.1] - 2024-05-15

47 changes: 45 additions & 2 deletions docs/griptape-framework/drivers/prompt-drivers.md
@@ -401,6 +401,16 @@ through the [prompt_model_driver](../../reference/griptape/drivers/prompt/base_m

The [AmazonSageMakerPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.md) uses [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) for inference on AWS.

!!! info
    For single-model endpoints, the `model` parameter does not need to be specified.
    For multi-model endpoints, the `model` parameter should be set to the inference component name.

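Below is a minimal sketch contrasting the two endpoint configurations described in the note above (the environment variable names are placeholders):

```python title="PYTEST_IGNORE"
import os

from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver

# Single-model endpoint: only the endpoint name is needed.
single_model_driver = AmazonSageMakerPromptDriver(
    endpoint=os.environ["SAGEMAKER_ENDPOINT_NAME"],
    prompt_model_driver=SageMakerLlamaPromptModelDriver(),
)

# Multi-model endpoint: `model` names the inference component and is passed to
# `invoke_endpoint` as `InferenceComponentName`.
multi_model_driver = AmazonSageMakerPromptDriver(
    endpoint=os.environ["SAGEMAKER_ENDPOINT_NAME"],
    model=os.environ["SAGEMAKER_INFERENCE_COMPONENT_NAME"],
    prompt_model_driver=SageMakerLlamaPromptModelDriver(),
)
```
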
!!! warning
    Make sure that the selected prompt model driver is compatible with the selected model. Note that even the same
    logical model can require different prompt model drivers depending on how it is bundled in the endpoint. For
    example, the response format differs for `Meta-Llama-3-8B-Instruct` when deployed via
    "Amazon SageMaker JumpStart" versus "Hugging Face on Amazon SageMaker".

##### LLaMA

```python title="PYTEST_IGNORE"
@@ -416,7 +426,7 @@ from griptape.config import StructureConfig
agent = Agent(
config=StructureConfig(
prompt_driver=AmazonSageMakerPromptDriver(
model=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"],
endpoint=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"],
prompt_model_driver=SageMakerLlamaPromptModelDriver(),
temperature=0.75,
)
@@ -432,6 +442,38 @@ agent = Agent(
agent.run("Hello!")
```

##### Llama 3

```python title="PYTEST_IGNORE"
import os
from griptape.structures import Agent
from griptape.drivers import (
AmazonSageMakerPromptDriver,
SageMakerJumpStartLlama3InstructPromptModelDriver,
)
from griptape.rules import Rule
from griptape.config import StructureConfig

agent = Agent(
config=StructureConfig(
prompt_driver=AmazonSageMakerPromptDriver(
endpoint=os.environ["SAGEMAKER_LLAMA_3_ENDPOINT_NAME"],
model=os.environ["SAGEMAKER_LLAMA_3_INFERENCE_COMPONENT_NAME"],
prompt_model_driver=SageMakerJumpStartLlama3InstructPromptModelDriver(),
temperature=0.75,
)
),
rules=[
Rule(
value="You are a helpful, respectful and honest assistant who is also a swarthy pirate."
"You only speak like a pirate and you never break character."
)
],
)

agent.run("Hello!")
```

##### Falcon

```python title="PYTEST_IGNORE"
@@ -446,7 +488,8 @@ from griptape.config import StructureConfig
agent = Agent(
config=StructureConfig(
prompt_driver=AmazonSageMakerPromptDriver(
model=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"],
model=os.environ["SAGEMAKER_FALCON_INFERENCE_COMPONENT_NAME"],
prompt_model_driver=SageMakerFalconPromptModelDriver(),
)
)
4 changes: 4 additions & 0 deletions griptape/drivers/__init__.py
@@ -53,6 +53,9 @@

from .prompt_model.base_prompt_model_driver import BasePromptModelDriver
from .prompt_model.sagemaker_llama_prompt_model_driver import SageMakerLlamaPromptModelDriver
from .prompt_model.sagemaker_jumpstart_llama3_instruct_prompt_model_driver import (
SageMakerJumpStartLlama3InstructPromptModelDriver,
)
from .prompt_model.sagemaker_falcon_prompt_model_driver import SageMakerFalconPromptModelDriver
from .prompt_model.bedrock_titan_prompt_model_driver import BedrockTitanPromptModelDriver
from .prompt_model.bedrock_claude_prompt_model_driver import BedrockClaudePromptModelDriver
@@ -157,6 +160,7 @@
"SqlDriver",
"BasePromptModelDriver",
"SageMakerLlamaPromptModelDriver",
"SageMakerJumpStartLlama3InstructPromptModelDriver",
"SageMakerFalconPromptModelDriver",
"BedrockTitanPromptModelDriver",
"BedrockClaudePromptModelDriver",
5 changes: 4 additions & 1 deletion griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@@ -18,6 +18,8 @@ class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
sagemaker_client: Any = field(
default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
)
endpoint: str = field(kw_only=True, metadata={"serializable": True})
model: str = field(default=None, kw_only=True, metadata={"serializable": True})
custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

@@ -32,10 +34,11 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
"parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
}
response = self.sagemaker_client.invoke_endpoint(
EndpointName=self.model,
EndpointName=self.endpoint,
ContentType="application/json",
Body=json.dumps(payload),
CustomAttributes=self.custom_attributes,
**({"InferenceComponentName": self.model} if self.model is not None else {}),
)

decoded_body = json.loads(response["Body"].read().decode("utf8"))
2 changes: 1 addition & 1 deletion griptape/drivers/prompt_model/base_prompt_model_driver.py
@@ -26,4 +26,4 @@ def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list |
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: ...

@abstractmethod
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact: ...
def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: ...
57 changes: 57 additions & 0 deletions griptape/drivers/prompt_model/sagemaker_jumpstart_llama3_instruct_prompt_model_driver.py
@@ -0,0 +1,57 @@
from __future__ import annotations
from attr import define, field
from griptape.artifacts import TextArtifact
from griptape.utils import PromptStack, import_optional_dependency
from griptape.drivers import BasePromptModelDriver
from griptape.tokenizers import HuggingFaceTokenizer


@define
class SageMakerJumpStartLlama3InstructPromptModelDriver(BasePromptModelDriver):
# Default context length for all Llama 3 models is 8K as per https://huggingface.co/blog/llama3
DEFAULT_MAX_TOKENS = 8000

_tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True)

@property
def tokenizer(self) -> HuggingFaceTokenizer:
if self._tokenizer is None:
self._tokenizer = HuggingFaceTokenizer(
tokenizer=import_optional_dependency("transformers").PreTrainedTokenizerFast.from_pretrained(
"meta-llama/Meta-Llama-3-8B-Instruct", model_max_length=self.DEFAULT_MAX_TOKENS
),
max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS,
)
return self._tokenizer

def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
# This input format is specific to the Llama 3 Instruct model prompt format.
# For more details see: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3#meta-llama-3-instruct
return "".join(
[
"<|begin_of_text|>",
*[
f"<|start_header_id|>{i.role}<|end_header_id|>\n\n{i.content}<|eot_id|>"
for i in prompt_stack.inputs
],
f"<|start_header_id|>{PromptStack.ASSISTANT_ROLE}<|end_header_id|>\n\n",
]
)

def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack)

return {
"max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
"temperature": self.prompt_driver.temperature,
# This stop parameter is specific to the Llama 3 Instruct model prompt format.
# docs: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3#meta-llama-3-instruct
"stop": "<|eot_id|>",
}

def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact:
# This output format is specific to the Llama 3 Instruct models when deployed via SageMaker JumpStart.
if isinstance(output, dict):
return TextArtifact(output["generated_text"])
else:
raise ValueError("output must be an instance of 'dict'")
22 changes: 14 additions & 8 deletions tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py
@@ -2,6 +2,7 @@
from griptape.artifacts import TextArtifact
from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver
from griptape.tokenizers import HuggingFaceTokenizer, OpenAiTokenizer
from griptape.utils import PromptStack
from io import BytesIO
from unittest.mock import Mock
import json
@@ -22,17 +23,19 @@ def mock_client(self, mocker):
return mocker.patch("boto3.Session").return_value.client.return_value

def test_init(self):
assert AmazonSageMakerPromptDriver(model="foo", prompt_model_driver=SageMakerLlamaPromptModelDriver())
assert AmazonSageMakerPromptDriver(endpoint="foo", prompt_model_driver=SageMakerLlamaPromptModelDriver())

def test_custom_tokenizer(self):
assert isinstance(
AmazonSageMakerPromptDriver(model="foo", prompt_model_driver=SageMakerLlamaPromptModelDriver()).tokenizer,
AmazonSageMakerPromptDriver(
endpoint="foo", prompt_model_driver=SageMakerLlamaPromptModelDriver()
).tokenizer,
HuggingFaceTokenizer,
)

assert isinstance(
AmazonSageMakerPromptDriver(
model="foo",
endpoint="foo",
tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL),
prompt_model_driver=SageMakerLlamaPromptModelDriver(),
).tokenizer,
@@ -41,8 +44,9 @@ def test_custom_tokenizer(self):

def test_try_run(self, mock_model_driver, mock_client):
# Given
driver = AmazonSageMakerPromptDriver(model="model", prompt_model_driver=mock_model_driver)
prompt_stack = "prompt-stack"
driver = AmazonSageMakerPromptDriver(endpoint="model", prompt_model_driver=mock_model_driver)
prompt_stack = PromptStack()
prompt_stack.add_user_input("prompt-stack")
response_body = "invoke-endpoint-response-body"
mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)}

@@ -53,7 +57,7 @@ def test_try_run(self, mock_model_driver, mock_client):
mock_model_driver.prompt_stack_to_model_input.assert_called_once_with(prompt_stack)
mock_model_driver.prompt_stack_to_model_params.assert_called_once_with(prompt_stack)
mock_client.invoke_endpoint.assert_called_once_with(
EndpointName=driver.model,
EndpointName=driver.endpoint,
ContentType="application/json",
Body=json.dumps(
{
@@ -68,12 +72,14 @@

def test_try_run_throws_on_empty_response(self, mock_model_driver, mock_client):
# Given
driver = AmazonSageMakerPromptDriver(model="model", prompt_model_driver=mock_model_driver)
driver = AmazonSageMakerPromptDriver(endpoint="model", prompt_model_driver=mock_model_driver)
mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body("")}
prompt_stack = PromptStack()
prompt_stack.add_user_input("prompt-stack")

# When
with pytest.raises(Exception) as e:
driver.try_run("prompt-stack")
driver.try_run(prompt_stack)

# Then
assert e.value.args[0] == "model response is empty"
2 changes: 1 addition & 1 deletion tests/unit/drivers/prompt_model/test_sagemaker_falcon_prompt_model_driver.py
@@ -8,7 +8,7 @@ class TestSageMakerFalconPromptModelDriver:
@pytest.fixture
def driver(self):
return AmazonSageMakerPromptDriver(
model="foo",
endpoint="endpoint-name",
session=boto3.Session(region_name="us-east-1"),
prompt_model_driver=SageMakerFalconPromptModelDriver(),
temperature=0.12345,
52 changes: 52 additions & 0 deletions tests/unit/drivers/prompt_model/test_sagemaker_jumpstart_llama3_instruct_prompt_model_driver.py
@@ -0,0 +1,52 @@
import boto3
import pytest
from griptape.drivers.prompt_model.sagemaker_jumpstart_llama3_instruct_prompt_model_driver import (
SageMakerJumpStartLlama3InstructPromptModelDriver,
)
from griptape.utils import PromptStack
from griptape.drivers import AmazonSageMakerPromptDriver


class TestSageMakerJumpStartLlama3InstructPromptModelDriver:
@pytest.fixture
def driver(self):
return AmazonSageMakerPromptDriver(
endpoint="endpoint-name",
model="inference-component-name",
session=boto3.Session(region_name="us-east-1"),
prompt_model_driver=SageMakerJumpStartLlama3InstructPromptModelDriver(),
temperature=0.12345,
).prompt_model_driver

@pytest.fixture
def stack(self):
stack = PromptStack()

stack.add_system_input("foo")
stack.add_user_input("bar")

return stack

def test_init(self, driver):
assert driver.prompt_driver is not None

def test_prompt_stack_to_model_input(self, driver, stack):
model_input = driver.prompt_stack_to_model_input(stack)

assert isinstance(model_input, str)
assert model_input == (
"<|begin_of_text|>"
"<|start_header_id|>system<|end_header_id|>\n\nfoo<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\nbar<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)

def test_prompt_stack_to_model_params(self, driver, stack):
assert driver.prompt_stack_to_model_params(stack)["max_new_tokens"] == 7991
assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345

def test_process_output(self, driver, stack):
assert driver.process_output({"generated_text": "foobar"}).value == "foobar"

def test_tokenizer_max_model_length(self, driver):
assert driver.tokenizer.tokenizer.model_max_length == 8000
2 changes: 1 addition & 1 deletion tests/unit/drivers/prompt_model/test_sagemaker_llama_prompt_model_driver.py
@@ -8,7 +8,7 @@ class TestSageMakerLlamaPromptModelDriver:
@pytest.fixture
def driver(self):
return AmazonSageMakerPromptDriver(
model="foo",
endpoint="endpoint-name",
session=boto3.Session(region_name="us-east-1"),
prompt_model_driver=SageMakerLlamaPromptModelDriver(),
temperature=0.12345,
