add support for AWS Bedrock as LLM provider #61

Open · wants to merge 1 commit into base: main

README.md (12 changes: 10 additions & 2 deletions)
@@ -57,11 +57,19 @@ source ./venv/bin/activate
```
and run: `pip install -r requirements.txt`.

(2) Duplicate the file `alpha_codium/settings/.secrets_template.toml`, rename it as `alpha_codium/settings/.secrets.toml`, and fill in your OpenAI API key:
(2) Determine your model provider.

For OpenAI:
- Duplicate the file `alpha_codium/settings/.secrets_template.toml`, rename it as `alpha_codium/settings/.secrets.toml`, and fill in your OpenAI API key:
```
[openai]
key = "..."
```
- Verify that one of the "chat" models is selected in 'alpha_codium/settings/configuration.toml' (see the example below).
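
For example, with the default shipped in this PR's `configuration.toml` (shown further down), the selected chat model line looks like:
```
[config]
model="gpt-3.5-turbo-16k"
```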

For AWS Bedrock:
- Log in to AWS using your standard AWS CLI method. LiteLLM uses boto3 to connect to AWS, so it will pick up either your profile credentials from ~/.aws/credentials or the environment variables AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION_NAME.
- Ensure the model in 'alpha_codium/settings/configuration.toml' is set to 'bedrock/<aws model id>', as in the sketch below.
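
A minimal sketch of the Bedrock selection, reusing the Claude 3.5 Sonnet model id that appears (commented out) in this PR's `configuration.toml`; any Bedrock model id your AWS account can access follows the same pattern:
```
[config]
model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
```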

(3) Download the processed CodeContest validation and test dataset from [Hugging Face](https://huggingface.co/datasets/talrid/CodeContests_valid_and_test_AlphaCodium/blob/main/codecontests_valid_and_test_processed_alpha_codium.zip), extract the zip file, and place the extracted folder in the root of the project.

@@ -127,7 +135,7 @@ To solve a custom problem with AlphaCodium, first create a json file that includ
python -m alpha_codium.solve_my_problem \
--my_problem_json_file /path/to/my_problem.json
```
- The `my_problem_json_file` is the path to to the custom problem json file.
- The `my_problem_json_file` is the path to the custom problem json file.

See the `my_problem_example.json` to see an example of a custom problem. The json file should include the following fields:
- `name` is the name of the problem.
alpha_codium/llm/ai_handler.py (161 changes: 100 additions & 61 deletions)
@@ -1,5 +1,6 @@
import logging
import os
from enum import Enum

import litellm
import openai
@@ -16,49 +17,49 @@
logger = get_logger(__name__)
OPENAI_RETRIES = 5

class Provider(Enum):
OPENAI = "openai"
DEEPSEEK_LEGACY = "deepseek-legacy"
AWSBEDROCK = "awsbedrock"
UNKNOWN = "unknown"

class AiHandler:
"""
This class handles interactions with LLM providers (OpenAI, DeepSeek, and AWS Bedrock) for chat completions.
It initializes the API keys and other settings from a configuration file,
and provides a method for performing chat completions via the LiteLLM acompletion API.
"""

def __init__(self):
"""
Initializes the provider and API settings from a configuration file.
Raises a ValueError if an OpenAI model is selected but the OpenAI key is missing.
"""
self.provider = Provider.UNKNOWN
self.limiter = AsyncLimiter(get_settings().config.max_requests_per_minute)
try:
if "gpt" in get_settings().get("config.model").lower():
if get_settings().get("config.model").lower().startswith("bedrock"):
self.provider = Provider.AWSBEDROCK
elif "gpt" in get_settings().get("config.model").lower():
self.provider = Provider.OPENAI
try:
openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
self.azure = False
if "deepseek" in get_settings().get("config.model"):
litellm.register_prompt_template(
model="huggingface/deepseek-ai/deepseek-coder-33b-instruct",
roles={
"system": {
"pre_message": "",
"post_message": "\n"
},
"user": {
"pre_message": "### Instruction:\n",
"post_message": "\n### Response:\n"
},
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
elif "deepseek" in get_settings().get("config.model"):
self.provider = Provider.DEEPSEEK_LEGACY
litellm.register_prompt_template(
model="huggingface/deepseek-ai/deepseek-coder-33b-instruct",
roles={
"system": {
"pre_message": "",
"post_message": "\n"
},
"user": {
"pre_message": "### Instruction:\n",
"post_message": "\n### Response:\n"
},
},

)
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
)

@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
self.azure = False
litellm.set_verbose = get_settings().get("litellm.set_verbose", False)
litellm.drop_params = get_settings().get("litellm.drop_params", False)

@retry(
exceptions=(AttributeError, RateLimitError),
@@ -75,46 +76,33 @@ async def chat_completion(
frequency_penalty: float = 0.0,
):
try:
deployment_id = self.deployment_id
if get_settings().config.verbosity_level >= 2:
logging.debug(
f"Generating completion with {model}"
f"{(' from deployment ' + deployment_id) if deployment_id else ''}"
)

async with self.limiter:
logger.info("-----------------")
logger.info("Running inference ...")
logger.info(f"Running inference ... provider: {self.provider}, model: {model}")
logger.debug(f"system:\n{system}")
logger.debug(f"user:\n{user}")
if "deepseek" in get_settings().get("config.model"):
response = await acompletion(
model="huggingface/deepseek-ai/deepseek-coder-33b-instruct",
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user},
],
api_base=get_settings().get("config.model"),
if self.provider == Provider.DEEPSEEK_LEGACY:
response = await self._deepseek_chat_completion(
model=model,
system=system,
user=user,
temperature=temperature,
repetition_penalty=frequency_penalty+1, # the scale of TGI is different from OpenAI
force_timeout=get_settings().config.ai_timeout,
max_tokens=2000,
stop=['<|EOT|>'],
frequency_penalty=frequency_penalty,
)
response["choices"][0]["message"]["content"] = response["choices"][0]["message"]["content"].rstrip()
if response["choices"][0]["message"]["content"].endswith("<|EOT|>"):
response["choices"][0]["message"]["content"] = response["choices"][0]["message"]["content"][:-7]
else:
response = await acompletion(
elif self.provider == Provider.OPENAI:
response = await self._openai_chat_completion(
model=model,
deployment_id=deployment_id,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user},
],
system=system,
user=user,
temperature=temperature,
frequency_penalty=frequency_penalty,
force_timeout=get_settings().config.ai_timeout,
)
else:
response = await self._awsbedrock_chat_completion(
model=model,
system=system,
user=user,
temperature=temperature
)
except (APIError) as e:
logging.error("Error during OpenAI inference")
@@ -133,3 +121,54 @@ async def chat_completion
logger.info('done')
logger.info("-----------------")
return resp, finish_reason

async def _deepseek_chat_completion(self, model, system, user, temperature, frequency_penalty):
response = await acompletion(
model="huggingface/deepseek-ai/deepseek-coder-33b-instruct",
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user},
],
api_base=get_settings().get("config.model"),
temperature=temperature,
repetition_penalty=frequency_penalty+1, # the scale of TGI is different from OpenAI
force_timeout=get_settings().config.ai_timeout,
max_tokens=2000,
stop=['<|EOT|>'],
)
response["choices"][0]["message"]["content"] = response["choices"][0]["message"]["content"].rstrip()
if response["choices"][0]["message"]["content"].endswith("<|EOT|>"):
response["choices"][0]["message"]["content"] = response["choices"][0]["message"]["content"][:-7]
return response

async def _openai_chat_completion(self, model, system, user, temperature, frequency_penalty):
deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if get_settings().config.verbosity_level >= 2:
logging.debug(
f"Generating completion with {model}"
f"{(' from deployment ' + deployment_id) if deployment_id else ''}"
)
response = await acompletion(
model=model,
deployment_id=deployment_id,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user},
],
temperature=temperature,
frequency_penalty=frequency_penalty,
force_timeout=get_settings().config.ai_timeout,
)
return response

async def _awsbedrock_chat_completion(self, model, system, user, temperature):
response = await acompletion(
model=model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user},
],
temperature=temperature,
)
return response
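
A minimal usage sketch of the refactored handler, assuming the truncated `chat_completion` signature above takes `model`, `system`, `user`, and `temperature` keywords (as the helper calls suggest) and that `.secrets.toml` / `configuration.toml` are already set up; the model id and prompt strings are placeholders:
```
import asyncio

from alpha_codium.llm.ai_handler import AiHandler


async def main():
    # the provider (OpenAI / DeepSeek / Bedrock) is inferred from config.model
    handler = AiHandler()
    resp, finish_reason = await handler.chat_completion(
        model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        system="You are a helpful coding assistant.",
        user="Write a Python function that reverses a string.",
        temperature=0.2,
    )
    print(finish_reason)
    print(resp)


asyncio.run(main())
```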
alpha_codium/settings/configuration.toml (9 changes: 7 additions & 2 deletions)
@@ -1,15 +1,20 @@
[config]
model="gpt-4-0125-preview"
#model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
#model="gpt-4-0125-preview"
# model="gpt-4o-2024-05-13"
# model="gpt-4-0613"
# model="gpt-3.5-turbo-16k"
model="gpt-3.5-turbo-16k"
frequency_penalty=0.1
ai_timeout=90 # seconds
fallback_models =[]
verbosity_level=0 # 0,1,2
private_dataset_cache_dir="~/.cache/huggingface/datasets/alpha_codium"
max_requests_per_minute=60

[litellm]
# set_verbose=true
# drop_params=true

[dataset]
evaluate_prev_solutions=false
num_iterations=1 # X iterations to try to solve the problem
requirements.txt (10 changes: 5 additions & 5 deletions)
@@ -1,25 +1,25 @@
dynaconf==3.1.12
fastapi==0.99.0
fastapi==0.115.0
PyGithub==1.59.*
retry==0.9.2
Jinja2==3.1.2
tiktoken==0.5.2
tiktoken==0.7.*
uvicorn==0.22.0
pytest==7.4.0
aiohttp==3.9.3
atlassian-python-api==3.39.0
GitPython==3.1.32
PyYAML==6.0.1
starlette-context==0.3.6
boto3==1.28.25
boto3==1.28.*
google-cloud-storage==2.10.0
ujson==5.8.0
azure-devops==7.1.0b3
msrest==0.7.1
##
openai
litellm
duckdb==0.9.2
litellm==1.48.2
duckdb==1.1.1
datasets
notebook
black