Add decoder for calling Anthropic models via Amazon Bedrock (#151)
* added configuration and decoder for Bedrock Anthropic

* clean up comments

* added test for bedrock anthropic

* added init file for bedrock_anthropic_completions

* added Claude Bedrock outputs
billcai authored Oct 29, 2023
1 parent bb19c59 commit 6e6d11f
Showing 10 changed files with 14,684 additions and 3 deletions.
9,662 changes: 9,662 additions & 0 deletions results/bedrock_claude/annotation_bedrock_claude.json

Large diffs are not rendered by default.

4,832 changes: 4,832 additions & 0 deletions results/bedrock_claude/model_outputs.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -22,7 +22,7 @@
]
PACKAGES_ANALYSIS = ["seaborn", "matplotlib", "jupyterlab"]
PACKAGES_LOCAL = ["accelerate", "transformers", "bitsandbytes", "xformers", "peft", "optimum", "scipy", "einops"]
PACKAGES_ALL_API = ["anthropic>=0.3.3", "huggingface_hub", "cohere", "replicate"]
PACKAGES_ALL_API = ["anthropic>=0.3.3", "huggingface_hub", "cohere", "replicate", "boto3>=1.28.58"]
PACKAGES_ALL = PACKAGES_LOCAL + PACKAGES_ALL_API + PACKAGES_ANALYSIS + PACKAGES_DEV

setuptools.setup(
10 changes: 10 additions & 0 deletions src/alpaca_eval/decoders/__init__.py
@@ -81,6 +81,16 @@ def get_fn_completions(name: Union[str, Callable]) -> Callable:
packages = ["vllm", "ray", "transformers"]
logging.exception(f"You need {packages} to use vllm_completions. Error:")
raise e

elif name == "bedrock_anthropic_completions":
try:
from .bedrock_anthropic import bedrock_anthropic_completions

return bedrock_anthropic_completions
except ImportError as e:
packages = ["boto3"]
logging.exception(f"You need {packages} to use bedrock_anthropic. Error:")
raise e

else:
raise ValueError(f"Unknown decoder: {name}")
132 changes: 132 additions & 0 deletions src/alpaca_eval/decoders/bedrock_anthropic.py
@@ -0,0 +1,132 @@
import copy
import functools
import json
import logging
import multiprocessing
import time
from typing import Optional, Sequence, Union

import boto3
import botocore.exceptions
import tqdm

from .. import utils

__all__ = ["bedrock_anthropic_completions"]

DEFAULT_NUM_PROCS = 3


def bedrock_anthropic_completions(
    prompts: Sequence[str],
    max_tokens_to_sample: Union[int, Sequence[int]] = 2048,
    model_name: str = "anthropic.claude-v1",
    num_procs: int = DEFAULT_NUM_PROCS,
    **decoding_kwargs,
) -> dict[str, list]:
    """Decode with Anthropic models served through Amazon Bedrock.

    Parameters
    ----------
    prompts : list of str
        Prompts to get completions for.
    max_tokens_to_sample : int or list of int, optional
        Maximum number of tokens to sample, shared across prompts or given per prompt.
    model_name : str, optional
        Name of the Bedrock model to use for decoding, e.g. "anthropic.claude-v1".
    num_procs : int, optional
        Number of parallel processes to use for decoding.
    decoding_kwargs :
        Additional kwargs passed through to the Bedrock Anthropic request body.
    """
    num_procs = num_procs or DEFAULT_NUM_PROCS

    n_examples = len(prompts)
    if n_examples == 0:
        logging.info("No samples to annotate.")
        return dict(completions=[], price_per_example=[], time_per_example=[], completions_all=[])
    else:
        to_log = f"Using `bedrock_anthropic_completions` on {n_examples} prompts using {model_name} and num_procs={num_procs}."
        logging.info(to_log)

    if isinstance(max_tokens_to_sample, int):
        max_tokens_to_sample = [max_tokens_to_sample] * n_examples

    inputs = zip(prompts, max_tokens_to_sample)

    kwargs = dict(model_name=model_name, **decoding_kwargs)
    kwargs_to_log = {k: v for k, v in kwargs.items() if "api_key" not in k}
    logging.info(f"Kwargs to completion: {kwargs_to_log}")
    with utils.Timer() as t:
        if num_procs == 1:
            responses = [_bedrock_anthropic_completion_helper(inp, **kwargs) for inp in tqdm.tqdm(inputs, desc="prompts")]
        else:
            with multiprocessing.Pool(num_procs) as p:
                partial_completion_helper = functools.partial(_bedrock_anthropic_completion_helper, **kwargs)
                responses = list(
                    tqdm.tqdm(
                        p.imap(partial_completion_helper, inputs),
                        desc="prompts",
                        total=len(prompts),
                    )
                )
    logging.info(f"Completed {n_examples} examples in {t}.")

    completions = responses

    # Bedrock does not currently return token counts, so the price cannot be estimated.
    price = [0 for _ in prompts]

    avg_time = [t.duration / n_examples] * len(completions)

    return dict(completions=completions, price_per_example=price, time_per_example=avg_time, completions_all=responses)


def _bedrock_anthropic_completion_helper(
    args: tuple[str, int],
    sleep_time: int = 2,
    region: Optional[str] = "us-west-2",
    model_name: str = "anthropic.claude-v1",
    temperature: Optional[float] = 0.7,
    **kwargs,
):
    prompt, max_tokens = args

    if not utils.check_pkg_atleast_version("boto3", "1.28.58"):
        raise ValueError("boto3 version must be at least 1.28.58. Use `pip install -U boto3`.")

    bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
    accept = "application/json"
    contentType = "application/json"

    kwargs.update(dict(max_tokens_to_sample=max_tokens, temperature=temperature))
    curr_kwargs = copy.deepcopy(kwargs)
    while True:
        try:
            body = json.dumps({"prompt": prompt, **curr_kwargs})
            response = bedrock.invoke_model(
                body=body, modelId=model_name, accept=accept, contentType=contentType
            )
            response = json.loads(response.get("body").read()).get("completion")
            break
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] == "ThrottlingException":
                logging.warning(f"Rate limit hit: {e}. Sleeping for {sleep_time} seconds.")
                time.sleep(sleep_time)
            else:
                # Non-throttling client errors (e.g. access denied) are not retryable.
                raise e
        except Exception as e:
            logging.error(f"Hit unknown error: {e}")
            raise e

    return response
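
For reference, the request this helper builds follows Bedrock's Anthropic text-completion schema. A standalone sketch of the same call (the prompt and sampling values are illustrative, and valid AWS credentials with Claude access in `us-west-2` are assumed):

```python
import json

import boto3

bedrock = boto3.client(service_name="bedrock-runtime", region_name="us-west-2")

# Same body shape the helper sends: the prompt plus Anthropic sampling kwargs.
body = json.dumps(
    {
        "prompt": "\n\nHuman: What is 1+1?\n\nAssistant:",
        "max_tokens_to_sample": 256,
        "temperature": 0.7,
    }
)
response = bedrock.invoke_model(
    body=body,
    modelId="anthropic.claude-v1",
    accept="application/json",
    contentType="application/json",
)
print(json.loads(response["body"].read())["completion"])
```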
12 changes: 12 additions & 0 deletions src/alpaca_eval/evaluators_configs/bedrock_claude/configs.yaml
@@ -0,0 +1,12 @@
bedrock_claude:
  prompt_template: "claude/basic_prompt.txt"
  fn_completions: "bedrock_anthropic_completions"
  completions_kwargs:
    model_name: "anthropic.claude-v1"
    max_tokens_to_sample: 50
    temperature: 0
  completion_parser_kwargs:
    outputs_to_match:
      1: '(?:^|\n) ?Output \(a\)'
      2: '(?:^|\n) ?Output \(b\)'
  batch_size: 1
12 changes: 12 additions & 0 deletions src/alpaca_eval/evaluators_configs/bedrock_claude_2/configs.yaml
@@ -0,0 +1,12 @@
bedrock_claude_2:
  prompt_template: "claude/basic_prompt.txt"
  fn_completions: "bedrock_anthropic_completions"
  completions_kwargs:
    model_name: "anthropic.claude-v2"
    max_tokens_to_sample: 50
    temperature: 0
  completion_parser_kwargs:
    outputs_to_match:
      1: '(?:^|\n) ?Output \(a\)'
      2: '(?:^|\n) ?Output \(b\)'
  batch_size: 1
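
In both configs, `outputs_to_match` maps a preference label to a regex over the judge's raw completion: a match on `Output (a)` means the first output won, `Output (b)` the second. An illustrative sketch of that matching logic (not the library's actual parser):

```python
import re

outputs_to_match = {
    1: r"(?:^|\n) ?Output \(a\)",
    2: r"(?:^|\n) ?Output \(b\)",
}


def parse_preference(completion: str):
    # Return the label of the first matching pattern, or None if unparsable.
    for preference, pattern in outputs_to_match.items():
        if re.search(pattern, completion):
            return preference
    return None


assert parse_preference("Output (b) is better because ...") == 2
assert parse_preference("no verdict here") is None
```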
7 changes: 7 additions & 0 deletions src/alpaca_eval/models_configs/bedrock_claude/configs.yaml
@@ -0,0 +1,7 @@
bedrock_claude:
  prompt_template: "claude/prompt.txt"
  fn_completions: "bedrock_anthropic_completions"
  completions_kwargs:
    model_name: "anthropic.claude-v1"
    max_tokens_to_sample: 2048
  pretty_name: "Bedrock Claude"
7 changes: 7 additions & 0 deletions src/alpaca_eval/models_configs/bedrock_claude_2/configs.yaml
@@ -0,0 +1,7 @@
bedrock_claude_2:
  prompt_template: "claude/prompt.txt"
  fn_completions: "bedrock_anthropic_completions"
  completions_kwargs:
    model_name: "anthropic.claude-v2"
    max_tokens_to_sample: 2048
  pretty_name: "Bedrock Claude 2"
11 changes: 9 additions & 2 deletions tests/integration_tests/test_decoders_integration.py
@@ -8,7 +8,7 @@
from alpaca_eval.decoders.huggingface_api import huggingface_api_completions
from alpaca_eval.decoders.huggingface_local import huggingface_local_completions
from alpaca_eval.decoders.openai import openai_completions

from alpaca_eval.decoders.bedrock_anthropic import bedrock_anthropic_completions

def _get_formatted_prompts(model):
    filename = list((constants.MODELS_CONFIG_DIR / model).glob("*.txt"))[0]
@@ -17,7 +17,6 @@ def _get_formatted_prompts(model):
    prompts = [template.format(instruction=prompt) for prompt in prompts]
    return prompts


@pytest.mark.slow
def test_openai_completions_integration():
prompts = _get_formatted_prompts("gpt4")
@@ -72,3 +71,11 @@ def test_vllm_local_completions_integration():
        prompts, model_name="OpenBuddy/openbuddy-openllama-3b-v10-bf16", max_new_tokens=100
    )
    assert len(results["completions"]) == len(prompts)

@pytest.mark.slow
def test_bedrock_anthropic_completions_integration():
    prompts = _get_formatted_prompts("claude")
    results = bedrock_anthropic_completions(prompts)
    assert len(results["completions"]) == len(prompts)
    assert "2" in results["completions"][0]
    assert "4" in results["completions"][1]
