Skip to content

Commit

Permalink
Formatted model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
viveksilimkhan1 authored Oct 25, 2023
1 parent 8b032de commit 621d2d1
Showing 1 changed file with 32 additions and 15 deletions.
47 changes: 32 additions & 15 deletions spacy_llm/models/bedrock/model.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import os
import json
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, Optional, List
from typing import Any, Dict, Iterable, List, Optional


class Models(str, Enum):
# Completion models
TITAN_EXPRESS = "amazon.titan-text-express-v1"
TITAN_LITE = "amazon.titan-text-lite-v1"

class Bedrock():

class Bedrock:
def __init__(
self,
model_id: str,
region: str,
config: Dict[Any, Any],
max_retries: int = 5
self, model_id: str, region: str, config: Dict[Any, Any], max_retries: int = 5
):
self._region = region
self._model_id = model_id
Expand All @@ -25,7 +23,7 @@ def __init__(
def get_session_kwargs(self) -> Dict[str, Optional[str]]:

# Fetch and check the credentials
profile = os.getenv("AWS_PROFILE") if not None else ""
profile = os.getenv("AWS_PROFILE") if not None else ""
secret_key_id = os.getenv("AWS_ACCESS_KEY_ID")
secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
session_token = os.getenv("AWS_SESSION_TOKEN")
Expand Down Expand Up @@ -58,7 +56,13 @@ def get_session_kwargs(self) -> Dict[str, Optional[str]]:
assert secret_access_key is not None
assert session_token is not None

session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token}
session_kwargs = {
"profile_name": profile,
"region_name": self._region,
"aws_access_key_id": secret_key_id,
"aws_secret_access_key": secret_access_key,
"aws_session_token": session_token,
}
return session_kwargs

def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
Expand All @@ -69,23 +73,36 @@ def _request(json_data: str) -> str:
try:
import boto3
except ImportError as err:
warnings.warn("To use Bedrock, you need to install boto3. Use pip install boto3 ")
warnings.warn(
"To use Bedrock, you need to install boto3. Use pip install boto3 "
)
raise err
from botocore.config import Config

session_kwargs = self.get_session_kwargs()
session = boto3.Session(**session_kwargs)
api_config = Config(retries = dict(max_attempts = self._max_retries))
api_config = Config(retries=dict(max_attempts=self._max_retries))
bedrock = session.client(service_name="bedrock-runtime", config=api_config)
accept = "application/json"
contentType = "application/json"
r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType)
responses = json.loads(r["body"].read().decode())["results"][0]["outputText"]
r = bedrock.invoke_model(
body=json_data,
modelId=self._model_id,
accept=accept,
contentType=contentType,
)
responses = json.loads(r["body"].read().decode())["results"][0][
"outputText"
]
return responses

for prompt in prompts:
if self._model_id in [Models.TITAN_LITE, Models.TITAN_EXPRESS]:
responses = _request(json.dumps({"inputText": prompt, "textGenerationConfig":self._config}))
responses = _request(
json.dumps(
{"inputText": prompt, "textGenerationConfig": self._config}
)
)

api_responses.append(responses)

Expand Down

0 comments on commit 621d2d1

Please sign in to comment.