diff --git a/spacy_llm/models/bedrock/model.py b/spacy_llm/models/bedrock/model.py index 46e0fa21..24f2df46 100644 --- a/spacy_llm/models/bedrock/model.py +++ b/spacy_llm/models/bedrock/model.py @@ -8,14 +8,6 @@ from ...registry import registry -try: - import boto3 - import botocore - from botocore.config import Config -except ImportError as err: - print("To use Bedrock, you need to install boto3. Use `pip install boto3` ") - raise err - class Models(str, Enum): # Completion models TITAN_EXPRESS = "amazon.titan-text-express-v1" @@ -36,7 +28,7 @@ def __init__( self._max_retries = max_retries # @property - def get_session(self): + def get_session(self) -> Dict[str, str]: # Fetch and check the credentials profile = os.getenv("AWS_PROFILE") if not None else "" @@ -71,22 +63,28 @@ def get_session(self): assert session_token is not None session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token} - bedrock = boto3.Session(**session_kwargs) - return bedrock + return session_kwargs def __call__(self, prompts: Iterable[str]) -> Iterable[str]: api_responses: List[str] = [] prompts = list(prompts) - api_config = Config(retries = dict(max_attempts = self._max_retries)) - def _request(json_data: str) -> str: - session = self.get_session() + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + try: + import boto3 + import botocore + except ImportError as err: + print("To use Bedrock, you need to install boto3. Use `pip install boto3` ") + raise err + from botocore.config import Config + session_kwargs = self.get_session_kwargs() + session = boto3.Session(**session_kwargs) print("Session:", session) bedrock = session.client(service_name="bedrock-runtime", config=api_config) - accept = 'application/json' - contentType = 'application/json' + accept = "application/json" + contentType = "application/json" r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType) - responses = json.loads(r['body'].read().decode())['results'][0]['outputText'] + responses = json.loads(r["body"].read().decode())["results"][0]["outputText"] return responses for prompt in prompts: