From 8408dbd4a5794ca2ab3e6139cad323e1bd84a738 Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 29 Sep 2024 08:14:46 +0000 Subject: [PATCH] more comments --- rdagent/oai/backends/az.py | 69 ++++++++++++++++++++++++++++++++++++++ rdagent/oai/llm_utils.py | 10 +++--- 2 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 rdagent/oai/backends/az.py diff --git a/rdagent/oai/backends/az.py b/rdagent/oai/backends/az.py new file mode 100644 index 00000000..0954d9ae --- /dev/null +++ b/rdagent/oai/backends/az.py @@ -0,0 +1,69 @@ +""" +TODO: +It is not complete now. + +Please refer to rdagent/oai/llm_utils.py:APIBackend for the future design +""" + +from azure.identity import DefaultAzureCredential, get_bearer_token_provider +import openai +from pydantic_settings import BaseSettings + + +class AzureConf(BaseSettings): + """ + TODO: move more settings here + """ + use_azure_token_provider: bool = False + managed_identity_client_id: str | None = None + chat_model: str = "gpt-4-turbo" + + chat_azure_api_base: str = "" + chat_azure_api_version: str = "" + + +class BaseAPI: + """ + TOOD: there may be some more shared methods in the BaseAPI + """ + pass + + +class AzureAPI(BaseAPI): + + def _get_credential(self): + dac_kwargs = {} + if AZURE_CONF.managed_identity_client_id is not None: + dac_kwargs["managed_identity_client_id"] = self.managed_identity_client_id + credential = DefaultAzureCredential(**dac_kwargs) + return credential + + def _get_client(self): + kwargs = {} + if AZURE_CONF.use_azure_token_provider: + kwargs["azure_ad_token_provider"]= get_bearer_token_provider( + self._get_credential(), + "https://cognitiveservices.azure.com/.default", + ) + return openai.AzureOpenAI( + api_version=AZURE_CONF.chat_azure_api_version, + azure_endpoint=AZURE_CONF.chat_azure_api_base, + **kwargs, + ) + + # def list_deployments(self): + # client = self._get_client() + # try: + # deployments = client.deployments.list() + # return [deployment for deployment in deployments] + # except Exception as e: + # print(f"An error occurred while listing deployments: {e}") + # return [] + +AZURE_CONF = AzureConf() + + +# if __name__ == "__main__": +# api = AzureAPI() +# deployments = api.list_deployments() +# print(deployments) diff --git a/rdagent/oai/llm_utils.py b/rdagent/oai/llm_utils.py index cb161fea..44e6aa2a 100644 --- a/rdagent/oai/llm_utils.py +++ b/rdagent/oai/llm_utils.py @@ -235,6 +235,12 @@ def display_history(self) -> None: class APIBackend: + """ + This is a unified interface for different backends. + + (xiao) thinks integerate all kinds of API in a single class is not a good design. + So we should split them into different classes in `oai/backends/` in the future. + """ # FIXME: (xiao) I think we should skip using self.xxxx # We can use self.cfg directly. If it is hard to 兼容 different settings of backends. We can split it into multiple BaseSettings. def __init__( # noqa: C901, PLR0912, PLR0915 @@ -384,10 +390,6 @@ def __init__( # noqa: C901, PLR0912, PLR0915 self.use_gcr_endpoint = self.cfg.use_gcr_endpoint self.retry_wait_seconds = self.cfg.retry_wait_seconds - def list_available_deployments(self): - if self.use_azure: - # TODO: - def build_chat_session( self, conversation_id: str | None = None,