From bb2f46d7cc9ecf92e297d5655c5514429b72a693 Mon Sep 17 00:00:00 2001 From: "Dr.MerdanBay" <110794035+KMerdan@users.noreply.github.com> Date: Fri, 20 Dec 2024 12:13:39 +0900 Subject: [PATCH] fix: add safe dictionary access for bedrock credentials (#11860) --- .../bedrock/get_bedrock_client.py | 16 ++++++++++++---- .../model_providers/bedrock/rerank/rerank.py | 5 ++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index a19ffbb20a6a9e..2ad37cef3b38f1 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -1,11 +1,19 @@ +from collections.abc import Mapping + import boto3 from botocore.config import Config +from core.model_runtime.errors.invoke import InvokeBadRequestError + + +def get_bedrock_client(service_name: str, credentials: Mapping[str, str]): + region_name = credentials.get("aws_region") + if not region_name: + raise InvokeBadRequestError("aws_region is required") + client_config = Config(region_name=region_name) + aws_access_key_id = credentials.get("aws_access_key_id") + aws_secret_access_key = credentials.get("aws_secret_access_key") -def get_bedrock_client(service_name, credentials=None): - client_config = Config(region_name=credentials["aws_region"]) - aws_access_key_id = credentials["aws_access_key_id"] - aws_secret_access_key = credentials["aws_secret_access_key"] if aws_access_key_id and aws_secret_access_key: # use aksk to call bedrock client = boto3.client( diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py index e134db646f3d39..9da23ba1b0f08f 100644 --- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py @@ -62,7 +62,10 @@ def _invoke( } ) modelId = model - region = credentials["aws_region"] + region = credentials.get("aws_region") + # region is a required field + if not region: + raise InvokeBadRequestError("aws_region is required in credentials") model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}" rerankingConfiguration = { "type": "BEDROCK_RERANKING_MODEL",