Skip to content

Commit

Permalink
more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Sep 29, 2024
1 parent 1fb9352 commit 8408dbd
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
69 changes: 69 additions & 0 deletions rdagent/oai/backends/az.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
TODO:
It is not complete now.
Please refer to rdagent/oai/llm_utils.py:APIBackend for the future design
"""

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import openai
from pydantic_settings import BaseSettings


class AzureConf(BaseSettings):
"""
TODO: move more settings here
"""
use_azure_token_provider: bool = False
managed_identity_client_id: str | None = None
chat_model: str = "gpt-4-turbo"

chat_azure_api_base: str = ""
chat_azure_api_version: str = ""


class BaseAPI:
"""
TOOD: there may be some more shared methods in the BaseAPI
"""
pass


class AzureAPI(BaseAPI):

def _get_credential(self):
dac_kwargs = {}
if AZURE_CONF.managed_identity_client_id is not None:
dac_kwargs["managed_identity_client_id"] = self.managed_identity_client_id
credential = DefaultAzureCredential(**dac_kwargs)
return credential

def _get_client(self):
kwargs = {}
if AZURE_CONF.use_azure_token_provider:
kwargs["azure_ad_token_provider"]= get_bearer_token_provider(
self._get_credential(),
"https://cognitiveservices.azure.com/.default",
)
return openai.AzureOpenAI(
api_version=AZURE_CONF.chat_azure_api_version,
azure_endpoint=AZURE_CONF.chat_azure_api_base,
**kwargs,
)

# def list_deployments(self):
# client = self._get_client()
# try:
# deployments = client.deployments.list()
# return [deployment for deployment in deployments]
# except Exception as e:
# print(f"An error occurred while listing deployments: {e}")
# return []

AZURE_CONF = AzureConf()


# if __name__ == "__main__":
# api = AzureAPI()
# deployments = api.list_deployments()
# print(deployments)
10 changes: 6 additions & 4 deletions rdagent/oai/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ def display_history(self) -> None:


class APIBackend:
"""
This is a unified interface for different backends.
(xiao) thinks integerate all kinds of API in a single class is not a good design.
So we should split them into different classes in `oai/backends/` in the future.
"""
# FIXME: (xiao) I think we should skip using self.xxxx
# We can use self.cfg directly. If it is hard to 兼容 different settings of backends. We can split it into multiple BaseSettings.
def __init__( # noqa: C901, PLR0912, PLR0915
Expand Down Expand Up @@ -384,10 +390,6 @@ def __init__( # noqa: C901, PLR0912, PLR0915
self.use_gcr_endpoint = self.cfg.use_gcr_endpoint
self.retry_wait_seconds = self.cfg.retry_wait_seconds

def list_available_deployments(self):
if self.use_azure:
# TODO:

def build_chat_session(
self,
conversation_id: str | None = None,
Expand Down

0 comments on commit 8408dbd

Please sign in to comment.