Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
viveksilimkhan1 authored Oct 24, 2023
1 parent 548ce74 commit 77b3c2c
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions spacy_llm/models/bedrock/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@

from ...registry import registry

try:
import boto3
import botocore
from botocore.config import Config
except ImportError as err:
print("To use Bedrock, you need to install boto3. Use `pip install boto3` ")
raise err

class Models(str, Enum):
# Completion models
TITAN_EXPRESS = "amazon.titan-text-express-v1"
Expand All @@ -36,7 +28,7 @@ def __init__(
self._max_retries = max_retries

# @property
def get_session(self):
def get_session(self) -> Dict[str, str]:

# Fetch and check the credentials
profile = os.getenv("AWS_PROFILE") if not None else ""
Expand Down Expand Up @@ -71,22 +63,28 @@ def get_session(self):
assert session_token is not None

session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token}
bedrock = boto3.Session(**session_kwargs)
return bedrock
return session_kwargs

def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
api_responses: List[str] = []
prompts = list(prompts)
api_config = Config(retries = dict(max_attempts = self._max_retries))

def _request(json_data: str) -> str:
session = self.get_session()
def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
try:
import boto3
import botocore
except ImportError as err:
print("To use Bedrock, you need to install boto3. Use `pip install boto3` ")
raise err
from botocore.config import Config
session_kwargs = self.get_session_kwargs()
session = boto3.Session(**session_kwargs)
print("Session:", session)
bedrock = session.client(service_name="bedrock-runtime", config=api_config)
accept = 'application/json'
contentType = 'application/json'
accept = "application/json"
contentType = "application/json"
r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType)
responses = json.loads(r['body'].read().decode())['results'][0]['outputText']
responses = json.loads(r["body"].read().decode())["results"][0]["outputText"]
return responses

for prompt in prompts:
Expand Down

0 comments on commit 77b3c2c

Please sign in to comment.