From 3a2cbaeb01a2d8b35af6041442cab9b595da5474 Mon Sep 17 00:00:00 2001
From: Alex-Wenner-FHR <78372056+Alex-Wenner-FHR@users.noreply.github.com>
Date: Wed, 18 Dec 2024 14:33:05 +0000
Subject: [PATCH 1/9] adds azure details to readme

---
 README.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/README.md b/README.md
index 77f5d43..a6c9c16 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,14 @@ You can sign up for an OpenAI account [here](https://platform.openai.com/) and m
 
 Important: Free-tier OpenAI accounts may be subject to rate limits, which could affect AG-A's performance. We recommend using a paid OpenAI API key for seamless functionality.
 
+#### Azure OpenAI Setup
+To use Azure OpenAI, you'll need to set the following environment variables:
+```bash
+export AZURE_OPENAI_API_KEY=<...>
+export OPENAI_API_VERSION=<...>
+export AZURE_OPENAI_ENDPOINT=<...>
+```
+
 ## Usage
 
 We support two ways of using AutoGluon Assistant: WebUI and CLI.
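
The three variables above are exactly what `langchain_openai.AzureChatOpenAI` reads from the environment at construction time. A minimal smoke test might look like the sketch below — the deployment name `gpt-4o` is an assumption, not something this series configures:

```python
# Sketch: verify the exported Azure variables are picked up by langchain.
# Assumes an Azure deployment named "gpt-4o" exists; substitute your own.
import os

from langchain_openai import AzureChatOpenAI

assert os.environ.get("AZURE_OPENAI_API_KEY"), "export AZURE_OPENAI_API_KEY first"

llm = AzureChatOpenAI(azure_deployment="gpt-4o", temperature=0)
print(llm.invoke("Reply with 'ok'").content)
```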
From 1c2d5ab62155d2592360428052ed9799ee65b604 Mon Sep 17 00:00:00 2001
From: Alex-Wenner-FHR <78372056+Alex-Wenner-FHR@users.noreply.github.com>
Date: Wed, 18 Dec 2024 14:33:24 +0000
Subject: [PATCH 2/9] adds AssistantAzureChatOpenAI to imports and return type

---
 src/autogluon/assistant/assistant.py | 35 +++++++++++++++++++++-------
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/src/autogluon/assistant/assistant.py b/src/autogluon/assistant/assistant.py
index 708b1dd..02540c9 100644
--- a/src/autogluon/assistant/assistant.py
+++ b/src/autogluon/assistant/assistant.py
@@ -8,7 +8,12 @@
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
 
-from autogluon.assistant.llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
+from autogluon.assistant.llm import (
+    AssistantChatBedrock,
+    AssistantChatOpenAI,
+    AssistantAzureChatOpenAI,
+    LLMFactory,
+)
 
 from .predictor import AutogluonTabularPredictor
 from .task import TabularPredictionTask
@@ -31,7 +36,9 @@ def timeout(seconds: int, error_message: Optional[str] = None):
 
     if sys.platform == "win32":
         # Windows implementation using threading
-        timer = threading.Timer(seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message)))
+        timer = threading.Timer(
+            seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message))
+        )
         timer.start()
         try:
             yield
@@ -55,7 +62,9 @@ class TabularPredictionAssistant:
 
     def __init__(self, config: DictConfig) -> None:
         self.config = config
-        self.llm: Union[AssistantChatOpenAI, AssistantChatBedrock] = LLMFactory.get_chat_model(config.llm)
+        self.llm: Union[
+            AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock
+        ] = LLMFactory.get_chat_model(config.llm)
         self.predictor = AutogluonTabularPredictor(config.autogluon)
         self.feature_transformers_config = get_feature_transformers_config(config)
 
@@ -95,13 +104,19 @@ def inference_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
             ):
                 task = preprocessor.transform(task)
             except Exception as e:
-                self.handle_exception(f"Task inference preprocessing: {preprocessor_class}", e)
+                self.handle_exception(
+                    f"Task inference preprocessing: {preprocessor_class}", e
+                )
 
         bold_start = "\033[1m"
         bold_end = "\033[0m"
 
-        logger.info(f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}")
-        logger.info(f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}")
+        logger.info(
+            f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}"
+        )
+        logger.info(
+            f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}"
+        )
         logger.info("Task understanding complete!")
 
         return task
@@ -111,7 +126,9 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
         task = self.inference_task(task)
         if self.feature_transformers_config:
             logger.info("Automatic feature generation starts...")
-            fe_transformers = [instantiate(ft_config) for ft_config in self.feature_transformers_config]
+            fe_transformers = [
+                instantiate(ft_config) for ft_config in self.feature_transformers_config
+            ]
             for fe_transformer in fe_transformers:
                 try:
                     with timeout(
@@ -120,7 +137,9 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask:
                     ):
                         task = fe_transformer.fit_transform(task)
                 except Exception as e:
-                    self.handle_exception(f"Task preprocessing: {fe_transformer.name}", e)
+                    self.handle_exception(
+                        f"Task preprocessing: {fe_transformer.name}", e
+                    )
             logger.info("Automatic feature generation complete!")
         else:
             logger.info("Automatic feature generation is disabled. ")

From 536cccba5a2ed0a34ba94c9b20f6242cb4582a88 Mon Sep 17 00:00:00 2001
From: Alex-Wenner-FHR <78372056+Alex-Wenner-FHR@users.noreply.github.com>
Date: Wed, 18 Dec 2024 14:33:37 +0000
Subject: [PATCH 3/9] Adds AssistantAzureChatOpenAI to export

---
 src/autogluon/assistant/llm/__init__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/autogluon/assistant/llm/__init__.py b/src/autogluon/assistant/llm/__init__.py
index e4c3210..7e680c8 100644
--- a/src/autogluon/assistant/llm/__init__.py
+++ b/src/autogluon/assistant/llm/__init__.py
@@ -1,7 +1,9 @@
-from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
+from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory, AssistantAzureChatOpenAI
 
 __all__ = [
+    "AssistantAzureChatOpenAI",
     "AssistantChatOpenAI",
     "AssistantChatBedrock",
     "LLMFactory",
+
 ]
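
PATCH 3/9 widens the public surface of `autogluon.assistant.llm`; after it, the Azure wrapper is importable next to the factory. A quick sanity check (a sketch, assuming the package is installed from this branch):

```python
# Sketch: the re-export added in PATCH 3/9 makes this import path valid.
from autogluon.assistant.llm import AssistantAzureChatOpenAI, LLMFactory

print(AssistantAzureChatOpenAI.__name__)  # class itself is defined in llm.py (PATCH 4/9)
print(LLMFactory.get_valid_providers())   # includes "azure" once PATCH 4/9 lands
```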
From c9f92f305ac1f7c1348fdba0a34e229e96b04e77 Mon Sep 17 00:00:00 2001
From: Alex-Wenner-FHR <78372056+Alex-Wenner-FHR@users.noreply.github.com>
Date: Wed, 18 Dec 2024 14:34:01 +0000
Subject: [PATCH 4/9] adds AzureChatOpenAI langchain and supporting methods to
 enable azure openai

---
 src/autogluon/assistant/llm/llm.py | 128 +++++++++++++++++++++++++----
 1 file changed, 112 insertions(+), 16 deletions(-)

diff --git a/src/autogluon/assistant/llm/llm.py b/src/autogluon/assistant/llm/llm.py
index f35b80f..0cbed32 100644
--- a/src/autogluon/assistant/llm/llm.py
+++ b/src/autogluon/assistant/llm/llm.py
@@ -7,9 +7,9 @@
 import botocore
 from langchain.schema import AIMessage, BaseMessage
 from langchain_aws import ChatBedrock
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, AzureChatOpenAI
 from omegaconf import DictConfig
-from openai import OpenAI
+from openai import OpenAI, AzureOpenAI
 from pydantic import BaseModel, Field
 from tenacity import retry, stop_after_attempt, wait_exponential
@@ -19,8 +19,47 @@
 
 
 class AssistantChatOpenAI(ChatOpenAI, BaseModel):
+
+
+    history_: List[Dict[str, Any]] = Field(default_factory=list)
+    input_: int = Field(default=0)
+    output_: int = Field(default=0)
+
+    def describe(self) -> Dict[str, Any]:
+        return {
+            "model": self.model_name,
+            "proxy": self.openai_proxy,
+            "history": self.history_,
+            "prompt_tokens": self.input_,
+            "completion_tokens": self.output_,
+        }
+
+    @retry(
+        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
+    )
+    def invoke(self, *args, **kwargs):
+        input_: List[BaseMessage] = args[0]
+        response = super().invoke(*args, **kwargs)
+
+        # Update token usage
+        if isinstance(response, AIMessage) and response.usage_metadata:
+            self.input_ += response.usage_metadata.get("input_tokens", 0)
+            self.output_ += response.usage_metadata.get("output_tokens", 0)
+
+        self.history_.append(
+            {
+                "input": [{"type": msg.type, "content": msg.content} for msg in input_],
+                "output": pprint.pformat(dict(response)),
+                "prompt_tokens": self.input_,
+                "completion_tokens": self.output_,
+            }
+        )
+        return response
+
+
+class AssistantAzureChatOpenAI(AzureChatOpenAI, BaseModel):
     """
-    AssistantChatOpenAI is a subclass of ChatOpenAI that traces the input and output of the model.
+    Mixin class for common functionalities between ChatOpenAI and AzureChatOpenAI.
     """
 
     history_: List[Dict[str, Any]] = Field(default_factory=list)
@@ -36,7 +75,9 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
+    @retry(
+        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
+    )
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         response = super().invoke(*args, **kwargs)
@@ -74,7 +115,9 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10))
+    @retry(
+        stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10)
+    )
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         try:
@@ -107,7 +150,11 @@ def get_openai_models() -> List[str]:
         try:
             client = OpenAI()
             models = client.models.list()
-            return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))]
+            return [
+                model.id
+                for model in models
+                if model.id.startswith(("gpt-3.5", "gpt-4"))
+            ]
         except Exception as e:
             print(f"Error fetching OpenAI models: {e}")
             return []
@@ -123,9 +170,25 @@ def get_bedrock_models() -> List[str]:
         except Exception as e:
             print(f"Error fetching Bedrock models: {e}")
             return []
 
+    @staticmethod
+    def get_azure_models() -> List[str]:
+        try:
+            client = AzureOpenAI()
+            models = client.models.list()
+            return [
+                model.id
+                for model in models
+                if model.id.startswith(("gpt-3.5", "gpt-4"))
+            ]
+        except Exception as e:
+            print(f"Error fetching Azure models: {e}")
+            return []
+
     @classmethod
     def get_valid_models(cls, provider):
-        if provider == "openai":
+        if provider == "azure":
+            return cls.get_azure_models()
+        elif provider == "openai":
             return cls.get_openai_models()
         elif provider == "bedrock":
             model_names = cls.get_bedrock_models()
@@ -136,17 +199,40 @@ def get_valid_models(cls, provider):
 
     @classmethod
     def get_valid_providers(cls):
-        return ["openai", "bedrock"]
+        return ["azure", "openai", "bedrock"]
 
     @staticmethod
-    def _get_openai_chat_model(config: DictConfig) -> AssistantChatOpenAI:
+    def _get_azure_chat_model(
+        config: DictConfig,
+    ) -> AssistantAzureChatOpenAI:
+        if "AZURE_OPENAI_API_KEY" in os.environ:
+            api_key = os.environ["AZURE_OPENAI_API_KEY"]
+        else:
+            raise Exception("Azure API env variable AZURE_OPENAI_API_KEY not set")
+
+        logger.info(
+            f"AGA is using model {config.model} from Azure to assist you with the task."
+        )
+        return AssistantAzureChatOpenAI(
+            api_key = api_key,
+            model_name=config.model,
+            temperature=config.temperature,
+            max_tokens=config.max_tokens,
+            verbose=config.verbose,
+        )
+
+    @staticmethod
+    def _get_openai_chat_model(
+        config: DictConfig,
+    ) -> AssistantChatOpenAI:
         if "OPENAI_API_KEY" in os.environ:
             api_key = os.environ["OPENAI_API_KEY"]
         else:
             raise Exception("OpenAI API env variable OPENAI_API_KEY not set")
 
-        logger.info(f"AGA is using model {config.model} from OpenAI to assist you with the task.")
-
+        logger.info(
+            f"AGA is using model {config.model} from OpenAI to assist you with the task."
+        )
         return AssistantChatOpenAI(
             model_name=config.model,
             temperature=config.temperature,
@@ -158,7 +244,9 @@ def _get_openai_chat_model(
 
     @staticmethod
     def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock:
-        logger.info(f"AGA is using model {config.model} from Bedrock to assist you with the task.")
+        logger.info(
+            f"AGA is using model {config.model} from Bedrock to assist you with the task."
+        )
 
         return AssistantChatBedrock(
             model_id=config.model,
@@ -172,9 +260,13 @@ def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock:
         )
 
     @classmethod
-    def get_chat_model(cls, config: DictConfig) -> Union[AssistantChatOpenAI, AssistantChatBedrock]:
+    def get_chat_model(
+        cls, config: DictConfig
+    ) -> Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock]:
         valid_providers = cls.get_valid_providers()
-        assert config.provider in valid_providers, f"{config.provider} is not a valid provider in: {valid_providers}"
+        assert (
+            config.provider in valid_providers
+        ), f"{config.provider} is not a valid provider in: {valid_providers}"
 
         valid_models = cls.get_valid_models(config.provider)
         assert (
@@ -182,9 +274,13 @@ def get_chat_model(
         ), f"{config.model} is not a valid model in: {valid_models} for provider {config.provider}"
 
         if config.model not in WHITE_LIST_LLM:
-            logger.warning(f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}")
+            logger.warning(
+                f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}"
+            )
 
-        if config.provider == "openai":
+        if config.provider == "azure":
+            return LLMFactory._get_azure_chat_model(config)
+        elif config.provider == "openai":
             return LLMFactory._get_openai_chat_model(config)
         elif config.provider == "bedrock":
             return LLMFactory._get_bedrock_chat_model(config)
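
With PATCH 4/9 applied, provider dispatch, retries, and token accounting all flow through `LLMFactory`. A sketch of the intended call path — the config values here are illustrative assumptions, not defaults shipped with the package, and the model-list check requires the Azure environment variables from PATCH 1/9 plus a reachable endpoint:

```python
# Sketch of the call path added above; config keys mirror _get_azure_chat_model.
from langchain.schema import HumanMessage
from omegaconf import OmegaConf

from autogluon.assistant.llm import LLMFactory

cfg = OmegaConf.create(
    {
        "provider": "azure",           # dispatches to _get_azure_chat_model
        "model": "gpt-4o-2024-08-06",  # must appear in get_azure_models()
        "temperature": 0,
        "max_tokens": 512,
        "verbose": False,
    }
)
llm = LLMFactory.get_chat_model(cfg)  # raises if required env variables are missing
llm.invoke([HumanMessage(content="What is the target column?")])
print(llm.describe()["prompt_tokens"], llm.describe()["completion_tokens"])
```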
From 9d2f099145ef386d6ee4d973f8d17b7966badc50 Mon Sep 17 00:00:00 2001
From: Alex-Wenner-FHR <78372056+Alex-Wenner-FHR@users.noreply.github.com>
Date: Wed, 18 Dec 2024 14:34:13 +0000
Subject: [PATCH 5/9] adds constants to support azure

---
 src/autogluon/assistant/ui/constants.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/autogluon/assistant/ui/constants.py b/src/autogluon/assistant/ui/constants.py
index 3008ede..9e22554 100644
--- a/src/autogluon/assistant/ui/constants.py
+++ b/src/autogluon/assistant/ui/constants.py
@@ -35,13 +35,15 @@
 # LLM configurations
 LLM_MAPPING = {
     "Claude 3.5 with Amazon Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0",
-    "GPT 4o": "gpt-4o-2024-08-06",
+    "GPT 4o with OpenAI": "gpt-4o-2024-08-06",
+    "GPT 4o with Azure": "gpt-4o-2024-08-06",
+
 }
 
-LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o"]
+LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o with OpenAI", "GPT 4o with Azure"]
 
 # Provider configuration
-PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o": "openai"}
+PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o with OpenAI": "openai", "GPT 4o with Azure": "azure"}
 
 INITIAL_STAGE = {
     "Task Understanding": [],
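
The WebUI resolves the selected dropdown label through both tables. A lookup sketch (the actual dispatch lives elsewhere in the UI code, not in this series):

```python
# Sketch: how a dropdown label maps to a provider and model id after this patch.
from autogluon.assistant.ui.constants import LLM_MAPPING, PROVIDER_MAPPING

label = "GPT 4o with Azure"
print(PROVIDER_MAPPING[label])  # "azure"
print(LLM_MAPPING[label])       # "gpt-4o-2024-08-06"
```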
"$CONFIG_FILE" + fi + + if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then + echo "AZURE_OPENAI_API_KEY=$tmp_AZURE_OPENAI_API_KEY" >> "$CONFIG_FILE" + echo "AZURE_OPENAI_ENDPOINT=$tmp_AZURE_OPENAI_ENDPOINT" >> "$CONFIG_FILE" fi # Export all variables @@ -113,6 +130,12 @@ save_configuration() { if [ -n "$tmp_OPENAI_API_KEY" ]; then export OPENAI_API_KEY="$tmp_OPENAI_API_KEY" + export OPENAI_API_VERSION="$tmp_OPENAI_API_VERSION" + fi + + if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then + export AZURE_OPENAI_API_KEY="$tmp_AZURE_OPENAI_API_KEY" + export AZURE_OPENAI_ENDPOINT="$tmp_AZURE_OPENAI_ENDPOINT" fi # Set proper permissions @@ -129,164 +152,4 @@ check_existing_config() { cp "$CONFIG_FILE" "${CONFIG_FILE}.backup" print_color "$GREEN" "Backup created at ${CONFIG_FILE}.backup" fi -} - -# Function to display current configuration file -display_config() { - if [ ! -f "$CONFIG_FILE" ]; then - print_color "$YELLOW" "No configuration file found at $CONFIG_FILE" - return - fi - - read_existing_config - - print_header "Current Configuration File" - - print_color "$GREEN" "AWS Bedrock Configuration:" - if [ -n "$tmp_AWS_ACCESS_KEY_ID" ]; then - echo "AWS_DEFAULT_REGION=${tmp_AWS_DEFAULT_REGION}" - echo "AWS_ACCESS_KEY_ID=${tmp_AWS_ACCESS_KEY_ID}" - echo "AWS_SECRET_ACCESS_KEY=********" - else - print_color "$YELLOW" "Bedrock is not configured" - fi - - echo - print_color "$GREEN" "OpenAI Configuration:" - if [ -n "$tmp_OPENAI_API_KEY" ]; then - echo "OPENAI_API_KEY=${tmp_OPENAI_API_KEY}" - else - print_color "$YELLOW" "OpenAI is not configured" - fi - echo -} - -# Function to display current environment variables -display_env_vars() { - print_header "Current Environment Variables" - - print_color "$GREEN" "AWS Bedrock Environment Variables:" - echo "AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-(not set)}" - echo "AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-(not set)}" - if [ -n "$AWS_SECRET_ACCESS_KEY" ]; then - echo "AWS_SECRET_ACCESS_KEY=********" - else - echo "AWS_SECRET_ACCESS_KEY=(not set)" - fi - - echo - print_color "$GREEN" "OpenAI Environment Variables:" - echo "OPENAI_API_KEY=${OPENAI_API_KEY:-(not set)}" - echo - - # Compare configuration with environment variables - if [ -f "$CONFIG_FILE" ]; then - read_existing_config - - local has_mismatch=false - if [ -n "$tmp_AWS_DEFAULT_REGION" ] && [ "$tmp_AWS_DEFAULT_REGION" != "$AWS_DEFAULT_REGION" ]; then - has_mismatch=true - fi - if [ -n "$tmp_AWS_ACCESS_KEY_ID" ] && [ "$tmp_AWS_ACCESS_KEY_ID" != "$AWS_ACCESS_KEY_ID" ]; then - has_mismatch=true - fi - if [ -n "$tmp_AWS_SECRET_ACCESS_KEY" ] && [ "$tmp_AWS_SECRET_ACCESS_KEY" != "$AWS_SECRET_ACCESS_KEY" ]; then - has_mismatch=true - fi - if [ -n "$tmp_OPENAI_API_KEY" ] && [ "$tmp_OPENAI_API_KEY" != "$OPENAI_API_KEY" ]; then - has_mismatch=true - fi - - if [ "$has_mismatch" = true ]; then - print_color "$YELLOW" "Warning: Some environment variables don't match the configuration file." - print_color "$YELLOW" "Run 'source $CONFIG_FILE' to sync them." - fi - fi -} - -# Function to configure providers -configure_provider() { - print_color "$GREEN" "Select your LLM provider to configure:" - echo "1) AWS Bedrock" - echo "2) OpenAI" - echo -n "Enter your choice (1/2): " - read provider_choice - - case $provider_choice in - 1) - print_color "$BLUE" "\nConfiguring AWS Bedrock..." - - while true; do - echo -n "Enter your AWS region (e.g., us-east-1): " - read AWS_DEFAULT_REGION - if validate_aws_region "$AWS_DEFAULT_REGION"; then - break - fi - print_color "$RED" "Invalid AWS region. Please enter a valid region." 
-            done
-
-            echo -n "Enter your AWS Access Key ID: "
-            read AWS_ACCESS_KEY_ID
-            echo -n "Enter your AWS Secret Access Key: "
-            read -s AWS_SECRET_ACCESS_KEY
-            echo
-
-            save_configuration "bedrock"
-            ;;
-
-        2)
-            print_color "$BLUE" "\nConfiguring OpenAI..."
-
-            while true; do
-                echo -n "Enter your OpenAI API Key (starts with sk-): "
-                read OPENAI_API_KEY
-                if validate_openai_api_key "$OPENAI_API_KEY"; then
-                    break
-                fi
-                print_color "$RED" "Invalid OpenAI API Key format. Please try again."
-            done
-
-            print_color "$RED" "Note: Free-tier OpenAI accounts may be subject to rate limits."
-            print_color "$RED" "We recommend using a paid OpenAI API key for seamless functionality."
-
-            save_configuration "openai"
-            ;;
-
-        *)
-            print_color "$RED" "Invalid choice. Exiting."
-            return 1
-            ;;
-    esac
-
-    print_color "$GREEN" "\nConfiguration complete!"
-    display_config
-}
-
-# Main script
-clear
-print_color "$BLUE" "=== AGA LLM Configuration Tool ==="
-echo
-print_color "$GREEN" "Select an option:"
-echo "1) Configure LLM providers"
-echo "2) View current configuration file"
-echo "3) View current environment variables"
-echo -n "Enter your choice (1/2/3): "
-read main_choice
-
-case $main_choice in
-    1)
-        check_existing_config
-        read_existing_config
-        configure_provider
-        ;;
-    2)
-        display_config
-        ;;
-    3)
-        display_env_vars
-        ;;
-    *)
-        print_color "$RED" "Invalid choice. Exiting."
-        return 1
-        ;;
-esac
+}
\ No newline at end of file

From 087a7c4b8b682c58a3d040a57f21dd3ae875a4e8 Mon Sep 17 00:00:00 2001
From: Alex-Wenner-FHR <78372056+Alex-Wenner-FHR@users.noreply.github.com>
Date: Wed, 18 Dec 2024 14:41:59 +0000
Subject: [PATCH 7/9] restores comment and adds comment to follow original
 standard

---
 src/autogluon/assistant/llm/llm.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/autogluon/assistant/llm/llm.py b/src/autogluon/assistant/llm/llm.py
index 0cbed32..b1f7ca7 100644
--- a/src/autogluon/assistant/llm/llm.py
+++ b/src/autogluon/assistant/llm/llm.py
@@ -19,7 +19,9 @@
 
 
 class AssistantChatOpenAI(ChatOpenAI, BaseModel):
-
+    """
+    AssistantChatOpenAI is a subclass of ChatOpenAI that traces the input and output of the model.
+    """
 
     history_: List[Dict[str, Any]] = Field(default_factory=list)
     input_: int = Field(default=0)
@@ -59,7 +61,7 @@ def invoke(self, *args, **kwargs):
 
 class AssistantAzureChatOpenAI(AzureChatOpenAI, BaseModel):
     """
-    Mixin class for common functionalities between ChatOpenAI and AzureChatOpenAI.
+    AssistantAzureChatOpenAI is a subclass of AzureChatOpenAI that traces the input and output of the model.
""" history_: List[Dict[str, Any]] = Field(default_factory=list) From 27b165bd76b985ace57235e4f22aca550142b7bc Mon Sep 17 00:00:00 2001 From: Alex-Wenner-FHR <78372056+Alex-Wenner-FHR@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:54:49 +0000 Subject: [PATCH 8/9] executes black, ruff, isort and adds check for required env vars --- src/autogluon/assistant/assistant.py | 37 +++++------------ src/autogluon/assistant/llm/__init__.py | 3 +- src/autogluon/assistant/llm/llm.py | 55 ++++++++----------------- src/autogluon/assistant/ui/constants.py | 7 +++- 4 files changed, 34 insertions(+), 68 deletions(-) diff --git a/src/autogluon/assistant/assistant.py b/src/autogluon/assistant/assistant.py index 02540c9..7ecfff3 100644 --- a/src/autogluon/assistant/assistant.py +++ b/src/autogluon/assistant/assistant.py @@ -8,12 +8,7 @@ from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf -from autogluon.assistant.llm import ( - AssistantChatBedrock, - AssistantChatOpenAI, - AssistantAzureChatOpenAI, - LLMFactory, -) +from autogluon.assistant.llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory from .predictor import AutogluonTabularPredictor from .task import TabularPredictionTask @@ -36,9 +31,7 @@ def timeout(seconds: int, error_message: Optional[str] = None): if sys.platform == "win32": # Windows implementation using threading - timer = threading.Timer( - seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message)) - ) + timer = threading.Timer(seconds, lambda: (_ for _ in ()).throw(TimeoutError(error_message))) timer.start() try: yield @@ -62,9 +55,9 @@ class TabularPredictionAssistant: def __init__(self, config: DictConfig) -> None: self.config = config - self.llm: Union[ - AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock - ] = LLMFactory.get_chat_model(config.llm) + self.llm: Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock] = ( + LLMFactory.get_chat_model(config.llm) + ) self.predictor = AutogluonTabularPredictor(config.autogluon) self.feature_transformers_config = get_feature_transformers_config(config) @@ -104,19 +97,13 @@ def inference_task(self, task: TabularPredictionTask) -> TabularPredictionTask: ): task = preprocessor.transform(task) except Exception as e: - self.handle_exception( - f"Task inference preprocessing: {preprocessor_class}", e - ) + self.handle_exception(f"Task inference preprocessing: {preprocessor_class}", e) bold_start = "\033[1m" bold_end = "\033[0m" - logger.info( - f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}" - ) - logger.info( - f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}" - ) + logger.info(f"{bold_start}Total number of prompt tokens:{bold_end} {self.llm.input_}") + logger.info(f"{bold_start}Total number of completion tokens:{bold_end} {self.llm.output_}") logger.info("Task understanding complete!") return task @@ -126,9 +113,7 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask: task = self.inference_task(task) if self.feature_transformers_config: logger.info("Automatic feature generation starts...") - fe_transformers = [ - instantiate(ft_config) for ft_config in self.feature_transformers_config - ] + fe_transformers = [instantiate(ft_config) for ft_config in self.feature_transformers_config] for fe_transformer in fe_transformers: try: with timeout( @@ -137,9 +122,7 @@ def preprocess_task(self, task: TabularPredictionTask) -> TabularPredictionTask: ): task = 
diff --git a/src/autogluon/assistant/llm/__init__.py b/src/autogluon/assistant/llm/__init__.py
index 7e680c8..d029dc5 100644
--- a/src/autogluon/assistant/llm/__init__.py
+++ b/src/autogluon/assistant/llm/__init__.py
@@ -1,9 +1,8 @@
-from .llm import AssistantChatBedrock, AssistantChatOpenAI, LLMFactory, AssistantAzureChatOpenAI
+from .llm import AssistantAzureChatOpenAI, AssistantChatBedrock, AssistantChatOpenAI, LLMFactory
 
 __all__ = [
     "AssistantAzureChatOpenAI",
     "AssistantChatOpenAI",
     "AssistantChatBedrock",
     "LLMFactory",
-
 ]
diff --git a/src/autogluon/assistant/llm/llm.py b/src/autogluon/assistant/llm/llm.py
index b1f7ca7..f054f93 100644
--- a/src/autogluon/assistant/llm/llm.py
+++ b/src/autogluon/assistant/llm/llm.py
@@ -7,9 +7,9 @@
 import botocore
 from langchain.schema import AIMessage, BaseMessage
 from langchain_aws import ChatBedrock
-from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from omegaconf import DictConfig
-from openai import OpenAI, AzureOpenAI
+from openai import AzureOpenAI, OpenAI
 from pydantic import BaseModel, Field
 from tenacity import retry, stop_after_attempt, wait_exponential
@@ -36,9 +36,7 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(
-        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
-    )
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         response = super().invoke(*args, **kwargs)
@@ -77,9 +75,7 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(
-        stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)
-    )
+    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10))
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         response = super().invoke(*args, **kwargs)
@@ -115,9 +111,7 @@ def describe(self) -> Dict[str, Any]:
             "completion_tokens": self.output_,
         }
 
-    @retry(
-        stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10)
-    )
+    @retry(stop=stop_after_attempt(50), wait=wait_exponential(multiplier=1, min=4, max=10))
     def invoke(self, *args, **kwargs):
         input_: List[BaseMessage] = args[0]
         try:
@@ -152,11 +146,7 @@ def get_openai_models() -> List[str]:
         try:
             client = OpenAI()
             models = client.models.list()
-            return [
-                model.id
-                for model in models
-                if model.id.startswith(("gpt-3.5", "gpt-4"))
-            ]
+            return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))]
         except Exception as e:
             print(f"Error fetching OpenAI models: {e}")
             return []
@@ -177,11 +167,7 @@ def get_azure_models() -> List[str]:
         try:
             client = AzureOpenAI()
             models = client.models.list()
-            return [
-                model.id
-                for model in models
-                if model.id.startswith(("gpt-3.5", "gpt-4"))
-            ]
+            return [model.id for model in models if model.id.startswith(("gpt-3.5", "gpt-4"))]
         except Exception as e:
             print(f"Error fetching Azure models: {e}")
             return []
@@ -212,11 +198,14 @@ def _get_azure_chat_model(
         else:
             raise Exception("Azure API env variable AZURE_OPENAI_API_KEY not set")
 
-        logger.info(
-            f"AGA is using model {config.model} from Azure to assist you with the task."
-        )
+        if "OPENAI_API_VERSION" not in os.environ:
+            raise Exception("Azure API env variable OPENAI_API_VERSION not set")
+        if "AZURE_OPENAI_ENDPOINT" not in os.environ:
+            raise Exception("Azure API env variable AZURE_OPENAI_ENDPOINT not set")
+
+        logger.info(f"AGA is using model {config.model} from Azure to assist you with the task.")
         return AssistantAzureChatOpenAI(
-            api_key = api_key,
+            api_key=api_key,
             model_name=config.model,
             temperature=config.temperature,
             max_tokens=config.max_tokens,
@@ -232,9 +221,7 @@ def _get_openai_chat_model(
         else:
             raise Exception("OpenAI API env variable OPENAI_API_KEY not set")
 
-        logger.info(
-            f"AGA is using model {config.model} from OpenAI to assist you with the task."
-        )
+        logger.info(f"AGA is using model {config.model} from OpenAI to assist you with the task.")
         return AssistantChatOpenAI(
             model_name=config.model,
             temperature=config.temperature,
@@ -246,9 +233,7 @@ def _get_openai_chat_model(
 
     @staticmethod
     def _get_bedrock_chat_model(config: DictConfig) -> AssistantChatBedrock:
-        logger.info(
-            f"AGA is using model {config.model} from Bedrock to assist you with the task."
-        )
+        logger.info(f"AGA is using model {config.model} from Bedrock to assist you with the task.")
 
         return AssistantChatBedrock(
             model_id=config.model,
@@ -266,9 +251,7 @@ def get_chat_model(
         cls, config: DictConfig
     ) -> Union[AssistantChatOpenAI, AssistantAzureChatOpenAI, AssistantChatBedrock]:
         valid_providers = cls.get_valid_providers()
-        assert (
-            config.provider in valid_providers
-        ), f"{config.provider} is not a valid provider in: {valid_providers}"
+        assert config.provider in valid_providers, f"{config.provider} is not a valid provider in: {valid_providers}"
 
         valid_models = cls.get_valid_models(config.provider)
         assert (
@@ -276,12 +259,10 @@ def get_chat_model(
         ), f"{config.model} is not a valid model in: {valid_models} for provider {config.provider}"
 
         if config.model not in WHITE_LIST_LLM:
-            logger.warning(
-                f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}"
-            )
+            logger.warning(f"{config.model} is not on the white list. Our white list models include {WHITE_LIST_LLM}")
 
         if config.provider == "azure":
             return LLMFactory._get_azure_chat_model(config)
         elif config.provider == "openai":
             return LLMFactory._get_openai_chat_model(config)
         elif config.provider == "bedrock":
             return LLMFactory._get_bedrock_chat_model(config)
diff --git a/src/autogluon/assistant/ui/constants.py b/src/autogluon/assistant/ui/constants.py
index 9e22554..973c2af 100644
--- a/src/autogluon/assistant/ui/constants.py
+++ b/src/autogluon/assistant/ui/constants.py
@@ -37,13 +37,16 @@
     "Claude 3.5 with Amazon Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0",
     "GPT 4o with OpenAI": "gpt-4o-2024-08-06",
     "GPT 4o with Azure": "gpt-4o-2024-08-06",
-
 }
 
 LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock", "GPT 4o with OpenAI", "GPT 4o with Azure"]
 
 # Provider configuration
-PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o with OpenAI": "openai", "GPT 4o with Azure": "azure"}
+PROVIDER_MAPPING = {
+    "Claude 3.5 with Amazon Bedrock": "bedrock",
+    "GPT 4o with OpenAI": "openai",
+    "GPT 4o with Azure": "azure",
+}
 
 INITIAL_STAGE = {
     "Task Understanding": [],
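
After PATCH 8/9, a missing Azure setting fails fast inside `_get_azure_chat_model` instead of surfacing on the first request. A reproduction sketch that deliberately unsets one variable — it calls the private helper directly to skip the network-backed model-list check, and the config values are illustrative assumptions:

```python
# Sketch: the checks added above should raise before any network call
# when a required Azure variable is missing.
import os

from omegaconf import OmegaConf

from autogluon.assistant.llm import LLMFactory

os.environ["AZURE_OPENAI_API_KEY"] = "dummy"     # placeholder, not a real key
os.environ.pop("OPENAI_API_VERSION", None)       # simulate a missing variable

cfg = OmegaConf.create(
    {"provider": "azure", "model": "gpt-4o-2024-08-06", "temperature": 0, "max_tokens": 512, "verbose": False}
)
try:
    LLMFactory._get_azure_chat_model(cfg)
except Exception as e:
    print(e)  # "Azure API env variable OPENAI_API_VERSION not set"
```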
From ecd223a16f58e5787c54a515cec783fd35972e99 Mon Sep 17 00:00:00 2001
From: Haoyang Fang <107515844+FANGAreNotGnu@users.noreply.github.com>
Date: Mon, 6 Jan 2025 16:00:49 -0800
Subject: [PATCH 9/9] Update configure_llms.sh

---
 tools/configure_llms.sh | 275 +++++++++++++++++++++++++++++++++++++---
 1 file changed, 254 insertions(+), 21 deletions(-)

diff --git a/tools/configure_llms.sh b/tools/configure_llms.sh
index 57190c7..83dde0f 100644
--- a/tools/configure_llms.sh
+++ b/tools/configure_llms.sh
@@ -22,8 +22,8 @@
 tmp_AWS_DEFAULT_REGION=""
 tmp_AWS_ACCESS_KEY_ID=""
 tmp_AWS_SECRET_ACCESS_KEY=""
 tmp_OPENAI_API_KEY=""
-tmp_OPENAI_API_VERSION=""
 tmp_AZURE_OPENAI_API_KEY=""
+tmp_OPENAI_API_VERSION=""
 tmp_AZURE_OPENAI_ENDPOINT=""
 
 # Function to print colored messages
@@ -63,6 +63,20 @@ validate_openai_api_key() {
     return 1
 }
 
+# Function to validate Azure OpenAI endpoint
+validate_azure_endpoint() {
+    local endpoint=$1
+    [[ $endpoint =~ ^https://[a-zA-Z0-9-]+\.openai\.azure\.com/?$ ]] && return 0
+    return 1
+}
+
+# Function to validate API version
+validate_api_version() {
+    local version=$1
+    [[ $version =~ ^[0-9]{4}-[0-9]{2}-[0-9]{2}$ ]] && return 0
+    return 1
+}
+
 # Function to read existing configuration into temporary variables
 read_existing_config() {
     if [ -f "$CONFIG_FILE" ]; then
@@ -73,8 +87,8 @@ read_existing_config() {
                 "AWS_ACCESS_KEY_ID") tmp_AWS_ACCESS_KEY_ID="$value" ;;
                 "AWS_SECRET_ACCESS_KEY") tmp_AWS_SECRET_ACCESS_KEY="$value" ;;
                 "OPENAI_API_KEY") tmp_OPENAI_API_KEY="$value" ;;
-                "OPENAI_API_VERSION") tmp_OPENAI_API_VERSION="$value" ;;
                 "AZURE_OPENAI_API_KEY") tmp_AZURE_OPENAI_API_KEY="$value" ;;
+                "OPENAI_API_VERSION") tmp_OPENAI_API_VERSION="$value" ;;
                 "AZURE_OPENAI_ENDPOINT") tmp_AZURE_OPENAI_ENDPOINT="$value" ;;
             esac
         fi
     done < "$CONFIG_FILE"
@@ -94,20 +108,21 @@ save_configuration() {
     # Create or truncate the config file
     echo "" > "$CONFIG_FILE" || { print_color "$RED" "Error: Cannot write to '$CONFIG_FILE'"; return 1; }
 
-    if [ "$provider" = "bedrock" ]; then
-        # Update AWS variables
-        tmp_AWS_DEFAULT_REGION="$AWS_DEFAULT_REGION"
-        tmp_AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID"
-        tmp_AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY"
-    elif [ "$provider" = "openai" ]; then
-        # Update OpenAI variables
-        tmp_OPENAI_API_KEY="$OPENAI_API_KEY"
-        tmp_OPENAI_API_VERSION="$OPENAI_API_VERSION"
-    elif [ "$provider" = "azure" ]; then
-        # Update Azure variables
-        tmp_AZURE_OPENAI_API_KEY="$AZURE_OPENAI_API_KEY"
-        tmp_AZURE_OPENAI_ENDPOINT="$AZURE_OPENAI_ENDPOINT"
-    fi
+    case "$provider" in
+        "bedrock")
+            tmp_AWS_DEFAULT_REGION="$AWS_DEFAULT_REGION"
+            tmp_AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID"
+            tmp_AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY"
+            ;;
+        "openai")
+            tmp_OPENAI_API_KEY="$OPENAI_API_KEY"
+            ;;
+        "azure")
+            tmp_AZURE_OPENAI_API_KEY="$AZURE_OPENAI_API_KEY"
+            tmp_OPENAI_API_VERSION="$OPENAI_API_VERSION"
+            tmp_AZURE_OPENAI_ENDPOINT="$AZURE_OPENAI_ENDPOINT"
+            ;;
+    esac
 
     # Save all configurations
     if [ -n "$tmp_AWS_ACCESS_KEY_ID" ]; then
@@ -113,11 +128,11 @@ save_configuration() {
 
     if [ -n "$tmp_OPENAI_API_KEY" ]; then
         echo "OPENAI_API_KEY=$tmp_OPENAI_API_KEY" >> "$CONFIG_FILE"
-        echo "OPENAI_API_VERSION=$tmp_OPENAI_API_VERSION" >> "$CONFIG_FILE"
     fi
-
+
     if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then
         echo "AZURE_OPENAI_API_KEY=$tmp_AZURE_OPENAI_API_KEY" >> "$CONFIG_FILE"
+        echo "OPENAI_API_VERSION=$tmp_OPENAI_API_VERSION" >> "$CONFIG_FILE"
         echo "AZURE_OPENAI_ENDPOINT=$tmp_AZURE_OPENAI_ENDPOINT" >> "$CONFIG_FILE"
     fi
 
@@ -130,11 +145,11 @@ save_configuration() {
 
     if [ -n "$tmp_OPENAI_API_KEY" ]; then
         export OPENAI_API_KEY="$tmp_OPENAI_API_KEY"
-        export OPENAI_API_VERSION="$tmp_OPENAI_API_VERSION"
     fi
-
+
     if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then
         export AZURE_OPENAI_API_KEY="$tmp_AZURE_OPENAI_API_KEY"
+        export OPENAI_API_VERSION="$tmp_OPENAI_API_VERSION"
         export AZURE_OPENAI_ENDPOINT="$tmp_AZURE_OPENAI_ENDPOINT"
     fi
 
     # Set proper permissions
@@ -152,4 +167,222 @@ check_existing_config() {
         cp "$CONFIG_FILE" "${CONFIG_FILE}.backup"
         print_color "$GREEN" "Backup created at ${CONFIG_FILE}.backup"
     fi
-}
\ No newline at end of file
+}
+
+# Function to display current configuration file
+display_config() {
+    if [ ! -f "$CONFIG_FILE" ]; then
+        print_color "$YELLOW" "No configuration file found at $CONFIG_FILE"
+        return
+    fi
+
+    read_existing_config
+
+    print_header "Current Configuration File"
+
+    print_color "$GREEN" "AWS Bedrock Configuration:"
+    if [ -n "$tmp_AWS_ACCESS_KEY_ID" ]; then
+        echo "AWS_DEFAULT_REGION=${tmp_AWS_DEFAULT_REGION}"
+        echo "AWS_ACCESS_KEY_ID=${tmp_AWS_ACCESS_KEY_ID}"
+        echo "AWS_SECRET_ACCESS_KEY=********"
+    else
+        print_color "$YELLOW" "Bedrock is not configured"
+    fi
+
+    echo
+    print_color "$GREEN" "OpenAI Configuration:"
+    if [ -n "$tmp_OPENAI_API_KEY" ]; then
+        echo "OPENAI_API_KEY=${tmp_OPENAI_API_KEY}"
+    else
+        print_color "$YELLOW" "OpenAI is not configured"
+    fi
+
+    echo
+    print_color "$GREEN" "Azure OpenAI Configuration:"
+    if [ -n "$tmp_AZURE_OPENAI_API_KEY" ]; then
+        echo "AZURE_OPENAI_API_KEY=********"
+        echo "OPENAI_API_VERSION=${tmp_OPENAI_API_VERSION}"
+        echo "AZURE_OPENAI_ENDPOINT=${tmp_AZURE_OPENAI_ENDPOINT}"
+    else
+        print_color "$YELLOW" "Azure OpenAI is not configured"
+    fi
+    echo
+}
+
+# Function to display current environment variables
+display_env_vars() {
+    print_header "Current Environment Variables"
+
+    print_color "$GREEN" "AWS Bedrock Environment Variables:"
+    echo "AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-(not set)}"
+    echo "AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-(not set)}"
+    if [ -n "$AWS_SECRET_ACCESS_KEY" ]; then
+        echo "AWS_SECRET_ACCESS_KEY=********"
+    else
+        echo "AWS_SECRET_ACCESS_KEY=(not set)"
+    fi
+
+    echo
+    print_color "$GREEN" "OpenAI Environment Variables:"
+    echo "OPENAI_API_KEY=${OPENAI_API_KEY:-(not set)}"
+
+    echo
+    print_color "$GREEN" "Azure OpenAI Environment Variables:"
+    if [ -n "$AZURE_OPENAI_API_KEY" ]; then
+        echo "AZURE_OPENAI_API_KEY=********"
+    else
+        echo "AZURE_OPENAI_API_KEY=(not set)"
+    fi
+    echo "OPENAI_API_VERSION=${OPENAI_API_VERSION:-(not set)}"
+    echo "AZURE_OPENAI_ENDPOINT=${AZURE_OPENAI_ENDPOINT:-(not set)}"
+    echo
+
+    # Compare configuration with environment variables
+    if [ -f "$CONFIG_FILE" ]; then
"$CONFIG_FILE" ]; then + read_existing_config + + local has_mismatch=false + if [ -n "$tmp_AWS_DEFAULT_REGION" ] && [ "$tmp_AWS_DEFAULT_REGION" != "$AWS_DEFAULT_REGION" ]; then + has_mismatch=true + fi + if [ -n "$tmp_AWS_ACCESS_KEY_ID" ] && [ "$tmp_AWS_ACCESS_KEY_ID" != "$AWS_ACCESS_KEY_ID" ]; then + has_mismatch=true + fi + if [ -n "$tmp_AWS_SECRET_ACCESS_KEY" ] && [ "$tmp_AWS_SECRET_ACCESS_KEY" != "$AWS_SECRET_ACCESS_KEY" ]; then + has_mismatch=true + fi + if [ -n "$tmp_OPENAI_API_KEY" ] && [ "$tmp_OPENAI_API_KEY" != "$OPENAI_API_KEY" ]; then + has_mismatch=true + fi + if [ -n "$tmp_AZURE_OPENAI_API_KEY" ] && [ "$tmp_AZURE_OPENAI_API_KEY" != "$AZURE_OPENAI_API_KEY" ]; then + has_mismatch=true + fi + if [ -n "$tmp_OPENAI_API_VERSION" ] && [ "$tmp_OPENAI_API_VERSION" != "$OPENAI_API_VERSION" ]; then + has_mismatch=true + fi + if [ -n "$tmp_AZURE_OPENAI_ENDPOINT" ] && [ "$tmp_AZURE_OPENAI_ENDPOINT" != "$AZURE_OPENAI_ENDPOINT" ]; then + has_mismatch=true + fi + + if [ "$has_mismatch" = true ]; then + print_color "$YELLOW" "Warning: Some environment variables don't match the configuration file." + print_color "$YELLOW" "Run 'source $CONFIG_FILE' to sync them." + fi + fi +} + +# Function to configure providers +configure_provider() { + print_color "$GREEN" "Select your LLM provider to configure:" + echo "1) AWS Bedrock" + echo "2) OpenAI" + echo "3) Azure OpenAI" + echo -n "Enter your choice (1/2/3): " + read provider_choice + + case $provider_choice in + 1) + print_color "$BLUE" "\nConfiguring AWS Bedrock..." + + while true; do + echo -n "Enter your AWS region (e.g., us-east-1): " + read AWS_DEFAULT_REGION + if validate_aws_region "$AWS_DEFAULT_REGION"; then + break + fi + print_color "$RED" "Invalid AWS region. Please enter a valid region." + done + + echo -n "Enter your AWS Access Key ID: " + read AWS_ACCESS_KEY_ID + echo -n "Enter your AWS Secret Access Key: " + read -s AWS_SECRET_ACCESS_KEY + echo + + save_configuration "bedrock" + ;; + + 2) + print_color "$BLUE" "\nConfiguring OpenAI..." + + while true; do + echo -n "Enter your OpenAI API Key (starts with sk-): " + read OPENAI_API_KEY + if validate_openai_api_key "$OPENAI_API_KEY"; then + break + fi + print_color "$RED" "Invalid OpenAI API Key format. Please try again." + done + + print_color "$RED" "Note: Free-tier OpenAI accounts may be subject to rate limits." + print_color "$RED" "We recommend using a paid OpenAI API key for seamless functionality." + + save_configuration "openai" + ;; + + 3) + print_color "$BLUE" "\nConfiguring Azure OpenAI..." + + echo -n "Enter your Azure OpenAI API Key: " + read -s AZURE_OPENAI_API_KEY + echo + + while true; do + echo -n "Enter the API version (YYYY-MM-DD format): " + read OPENAI_API_VERSION + if validate_api_version "$OPENAI_API_VERSION"; then + break + fi + print_color "$RED" "Invalid API version format. Please use YYYY-MM-DD format." + done + + while true; do + echo -n "Enter your Azure OpenAI endpoint (https://.openai.azure.com): " + read AZURE_OPENAI_ENDPOINT + if validate_azure_endpoint "$AZURE_OPENAI_ENDPOINT"; then + break + fi + print_color "$RED" "Invalid endpoint format. Please enter a valid Azure OpenAI endpoint." + done + + save_configuration "azure" + ;; + + *) + print_color "$RED" "Invalid choice. Exiting." + return 1 + ;; + esac + + print_color "$GREEN" "\nConfiguration complete!" 
+    display_config
+}
+
+# Main script
+clear
+print_color "$BLUE" "=== AGA LLM Configuration Tool ==="
+echo
+print_color "$GREEN" "Select an option:"
+echo "1) Configure LLM providers"
+echo "2) View current configuration file"
+echo "3) View current environment variables"
+echo -n "Enter your choice (1/2/3): "
+read main_choice
+
+case $main_choice in
+    1)
+        check_existing_config
+        read_existing_config
+        configure_provider
+        ;;
+    2)
+        display_config
+        ;;
+    3)
+        display_env_vars
+        ;;
+    *)
+        print_color "$RED" "Invalid choice. Exiting."
+        return 1
+        ;;
+esac