diff --git a/spacy_llm/models/bedrock/model.py b/spacy_llm/models/bedrock/model.py index 2135677e..53c23cd0 100644 --- a/spacy_llm/models/bedrock/model.py +++ b/spacy_llm/models/bedrock/model.py @@ -2,11 +2,7 @@ import json import warnings from enum import Enum -from typing import Any, Dict, Iterable, Optional, Type, List, Sized, Tuple - -from confection import SimpleFrozenDict - -from ...registry import registry +from typing import Any, Dict, Iterable, Optional, List class Models(str, Enum): # Completion models @@ -29,7 +25,7 @@ def __init__( # @property def get_session_kwargs(self) -> Dict[str, Optional[str]]: - + # Fetch and check the credentials profile = os.getenv("AWS_PROFILE") if not None else "" secret_key_id = os.getenv("AWS_ACCESS_KEY_ID") @@ -65,14 +61,13 @@ def get_session_kwargs(self) -> Dict[str, Optional[str]]: session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token} return session_kwargs - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[str]) -> Iterable[str]: api_responses: List[str] = [] prompts = list(prompts) def _request(json_data: str) -> str: try: import boto3 - import botocore except ImportError as err: print("To use Bedrock, you need to install boto3. Use `pip install boto3` ") raise err @@ -80,7 +75,6 @@ def _request(json_data: str) -> str: session_kwargs = self.get_session_kwargs() session = boto3.Session(**session_kwargs) api_config = Config(retries = dict(max_attempts = self._max_retries)) - print("Session:", session) bedrock = session.client(service_name="bedrock-runtime", config=api_config) accept = "application/json" contentType = "application/json"