From 621d2d14f6461afed1963115d22527e0301c017e Mon Sep 17 00:00:00 2001 From: Vivek Silimkhan <126159777+viveksilimkhan1@users.noreply.github.com> Date: Wed, 25 Oct 2023 17:02:14 +0530 Subject: [PATCH] Formatted model.py --- spacy_llm/models/bedrock/model.py | 47 +++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/spacy_llm/models/bedrock/model.py b/spacy_llm/models/bedrock/model.py index 7fa0af4e..2ebec5b1 100644 --- a/spacy_llm/models/bedrock/model.py +++ b/spacy_llm/models/bedrock/model.py @@ -1,21 +1,19 @@ -import os import json +import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, Optional, List +from typing import Any, Dict, Iterable, List, Optional + class Models(str, Enum): # Completion models TITAN_EXPRESS = "amazon.titan-text-express-v1" TITAN_LITE = "amazon.titan-text-lite-v1" -class Bedrock(): + +class Bedrock: def __init__( - self, - model_id: str, - region: str, - config: Dict[Any, Any], - max_retries: int = 5 + self, model_id: str, region: str, config: Dict[Any, Any], max_retries: int = 5 ): self._region = region self._model_id = model_id @@ -25,7 +23,7 @@ def __init__( def get_session_kwargs(self) -> Dict[str, Optional[str]]: # Fetch and check the credentials - profile = os.getenv("AWS_PROFILE") if not None else "" + profile = os.getenv("AWS_PROFILE") if not None else "" secret_key_id = os.getenv("AWS_ACCESS_KEY_ID") secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") session_token = os.getenv("AWS_SESSION_TOKEN") @@ -58,7 +56,13 @@ def get_session_kwargs(self) -> Dict[str, Optional[str]]: assert secret_access_key is not None assert session_token is not None - session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token} + session_kwargs = { + "profile_name": profile, + "region_name": self._region, + "aws_access_key_id": secret_key_id, + "aws_secret_access_key": secret_access_key, + "aws_session_token": session_token, + } return session_kwargs def __call__(self, prompts: Iterable[str]) -> Iterable[str]: @@ -69,23 +73,36 @@ def _request(json_data: str) -> str: try: import boto3 except ImportError as err: - warnings.warn("To use Bedrock, you need to install boto3. Use pip install boto3 ") + warnings.warn( + "To use Bedrock, you need to install boto3. Use pip install boto3 " + ) raise err from botocore.config import Config session_kwargs = self.get_session_kwargs() session = boto3.Session(**session_kwargs) - api_config = Config(retries = dict(max_attempts = self._max_retries)) + api_config = Config(retries=dict(max_attempts=self._max_retries)) bedrock = session.client(service_name="bedrock-runtime", config=api_config) accept = "application/json" contentType = "application/json" - r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType) - responses = json.loads(r["body"].read().decode())["results"][0]["outputText"] + r = bedrock.invoke_model( + body=json_data, + modelId=self._model_id, + accept=accept, + contentType=contentType, + ) + responses = json.loads(r["body"].read().decode())["results"][0][ + "outputText" + ] return responses for prompt in prompts: if self._model_id in [Models.TITAN_LITE, Models.TITAN_EXPRESS]: - responses = _request(json.dumps({"inputText": prompt, "textGenerationConfig":self._config})) + responses = _request( + json.dumps( + {"inputText": prompt, "textGenerationConfig": self._config} + ) + ) api_responses.append(responses)