From f99c82b98448b004b5bd220a1d1af3e632a1778c Mon Sep 17 00:00:00 2001 From: Vivek Silimkhan Date: Tue, 31 Oct 2023 01:31:27 +0530 Subject: [PATCH] Fix --- spacy_llm/models/rest/bedrock/model.py | 30 +++++++++++--------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/spacy_llm/models/rest/bedrock/model.py b/spacy_llm/models/rest/bedrock/model.py index 3ab1b7cc..2592206f 100644 --- a/spacy_llm/models/rest/bedrock/model.py +++ b/spacy_llm/models/rest/bedrock/model.py @@ -2,7 +2,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Tuple from ..base import REST @@ -16,17 +16,15 @@ class Models(str, Enum): AI21_JURASSIC_MID = "ai21.j2-mid-v1" -class ModelParams(List, Enum): - # Params with default values - TITAN = ["maxTokenCount", "stopSequences", "temperature", "topP"] - AI21_JURASSIC = [ - "maxTokens", - "temperature", - "topP", - "countPenalty", - "presencePenalty", - "frequencyPenalty", - ] +TITAN_PARAMS = ["maxTokenCount", "stopSequences", "temperature", "topP"] +AI21_JURASSIC_PARAMS = [ + "maxTokens", + "temperature", + "topP", + "countPenalty", + "presencePenalty", + "frequencyPenalty", +] class Bedrock(REST): @@ -45,9 +43,9 @@ def __init__( self._config = {} if self._model_id in [Models.TITAN_EXPRESS, Models.TITAN_LITE]: - config_params = ModelParams.TITAN + config_params = TITAN_PARAMS if self._model_id in [Models.AI21_JURASSIC_ULTRA, Models.AI21_JURASSIC_MID]: - config_params = ModelParams.AI21_JURASSIC + config_params = AI21_JURASSIC_PARAMS for i in config_params: self._config[i] = config[i] @@ -164,14 +162,12 @@ def _request(json_data: str) -> str: def _verify_auth(self) -> None: try: import boto3 - from botocore.config import Config from botocore.exceptions import NoCredentialsError - from botocore.exceptions import ClientError session_kwargs = self.get_session_kwargs() session = boto3.Session(**session_kwargs) bedrock = session.client(service_name="bedrock") - models = bedrock.list_foundation_models() + bedrock.list_foundation_models() except NoCredentialsError: raise NoCredentialsError