-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for Llama 3 on Amazon SageMaker
- Loading branch information
1 parent
f7bf180
commit ef9e949
Showing
10 changed files
with
182 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
griptape/drivers/prompt_model/sagemaker_jumpstart_llama3_instruct_prompt_model_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
from attr import define, field | ||
from griptape.artifacts import TextArtifact | ||
from griptape.utils import PromptStack, import_optional_dependency | ||
from griptape.drivers import BasePromptModelDriver | ||
from griptape.tokenizers import HuggingFaceTokenizer | ||
|
||
|
||
@define
class SageMakerJumpStartLlama3InstructPromptModelDriver(BasePromptModelDriver):
    """Prompt model driver for Llama 3 Instruct models deployed via SageMaker JumpStart.

    Translates a griptape ``PromptStack`` into the Llama 3 Instruct prompt format,
    builds the inference parameters for the endpoint, and converts the endpoint's
    response payload into a ``TextArtifact``.

    Attributes:
        _tokenizer: Lazily-built HuggingFace tokenizer; constructed on first access
            of the ``tokenizer`` property.
    """

    # Default context length for all Llama 3 models is 8K as per https://huggingface.co/blog/llama3
    DEFAULT_MAX_TOKENS = 8000

    _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> HuggingFaceTokenizer:
        """Lazily initialize and return the HuggingFace tokenizer for Llama 3.

        The transformers dependency is imported on demand so that it stays optional
        for users who do not use this driver.
        """
        if self._tokenizer is None:
            self._tokenizer = HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").PreTrainedTokenizerFast.from_pretrained(
                    "meta-llama/Meta-Llama-3-8B-Instruct", model_max_length=self.DEFAULT_MAX_TOKENS
                ),
                max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS,
            )
        return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
        """Render the prompt stack as a single Llama 3 Instruct prompt string.

        This input format is specific to the Llama 3 Instruct model prompt format.
        For more details see:
        https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3#meta-llama-3-instruct
        """
        return "".join(
            [
                "<|begin_of_text|>",
                *[
                    f"<|start_header_id|>{i.role}<|end_header_id|>\n\n{i.content}<|eot_id|>"
                    for i in prompt_stack.inputs
                ],
                # Trailing assistant header cues the model to generate the reply.
                f"<|start_header_id|>{PromptStack.ASSISTANT_ROLE}<|end_header_id|>\n\n",
            ]
        )

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        """Build the inference parameters dict sent to the SageMaker endpoint."""
        prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack)

        return {
            "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
            # This stop parameter is specific to the Llama 3 Instruct model prompt format.
            # docs: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3#meta-llama-3-instruct
            "stop": "<|eot_id|>",
        }

    def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact:
        """Convert the endpoint response payload into a ``TextArtifact``.

        This output format is specific to the Llama 3 Instruct models when deployed
        via SageMaker JumpStart. Some JumpStart endpoints wrap the response dict in a
        single-element list, so both shapes are accepted (matching the declared
        parameter type); ``str``/``bytes`` payloads are rejected.

        Raises:
            ValueError: If the output is not a dict or a non-empty list of dicts.
        """
        if isinstance(output, list):
            if not output:
                raise ValueError("output must not be an empty list")
            # JumpStart batch-style responses carry the payload as the first element.
            output = output[0]
        if isinstance(output, dict):
            return TextArtifact(output["generated_text"])
        else:
            raise ValueError("output must be an instance of 'dict' or 'list[dict]'")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 52 additions & 0 deletions
52
...nit/drivers/prompt_models/test_sagemaker_jumpstart_llama3_instruct_prompt_model_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import boto3 | ||
import pytest | ||
from griptape.drivers.prompt_model.sagemaker_jumpstart_llama3_instruct_prompt_model_driver import ( | ||
SageMakerJumpStartLlama3InstructPromptModelDriver, | ||
) | ||
from griptape.utils import PromptStack | ||
from griptape.drivers import AmazonSageMakerPromptDriver | ||
|
||
|
||
class TestSageMakerJumpStartLlama3InstructPromptModelDriver:
    """Unit tests for the Llama 3 Instruct SageMaker JumpStart prompt model driver."""

    @pytest.fixture
    def driver(self):
        # Build through AmazonSageMakerPromptDriver so the back-reference from the
        # prompt model driver to its parent prompt driver is wired up.
        sagemaker_driver = AmazonSageMakerPromptDriver(
            endpoint="endpoint-name",
            model="inference-component-name",
            session=boto3.Session(region_name="us-east-1"),
            prompt_model_driver=SageMakerJumpStartLlama3InstructPromptModelDriver(),
            temperature=0.12345,
        )
        return sagemaker_driver.prompt_model_driver

    @pytest.fixture
    def stack(self):
        prompt_stack = PromptStack()
        prompt_stack.add_system_input("foo")
        prompt_stack.add_user_input("bar")
        return prompt_stack

    def test_init(self, driver):
        assert driver.prompt_driver is not None

    def test_prompt_stack_to_model_input(self, driver, stack):
        model_input = driver.prompt_stack_to_model_input(stack)

        expected = (
            "<|begin_of_text|>"
            "<|start_header_id|>system<|end_header_id|>\n\nfoo<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n\nbar<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        assert isinstance(model_input, str)
        assert model_input == expected

    def test_prompt_stack_to_model_params(self, driver, stack):
        params = driver.prompt_stack_to_model_params(stack)

        assert params["max_new_tokens"] == 7991
        assert params["temperature"] == 0.12345

    def test_process_output(self, driver, stack):
        artifact = driver.process_output({"generated_text": "foobar"})

        assert artifact.value == "foobar"

    def test_tokenizer_max_model_length(self, driver):
        assert driver.tokenizer.tokenizer.model_max_length == 8000
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters