Skip to content

Commit

Permalink
fix: add safe dictionary access for bedrock credentials (#11860)
Browse files Browse the repository at this point in the history
  • Loading branch information
KMerdan authored Dec 20, 2024
1 parent 463fbe2 commit bb2f46d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from collections.abc import Mapping

import boto3
from botocore.config import Config

from core.model_runtime.errors.invoke import InvokeBadRequestError


def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
region_name = credentials.get("aws_region")
if not region_name:
raise InvokeBadRequestError("aws_region is required")
client_config = Config(region_name=region_name)
aws_access_key_id = credentials.get("aws_access_key_id")
aws_secret_access_key = credentials.get("aws_secret_access_key")

def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"]
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def _invoke(
}
)
modelId = model
region = credentials["aws_region"]
region = credentials.get("aws_region")
# region is a required field
if not region:
raise InvokeBadRequestError("aws_region is required in credentials")
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL",
Expand Down

0 comments on commit bb2f46d

Please sign in to comment.