Skip to content

Commit

Permalink
Support managed_identity_client_id for DefaultAzureCredential
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Jul 1, 2024
1 parent 40b8250 commit 8c27cf9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions rdagent/core/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@


class RDAgentSettings(BaseSettings):
# TODO: (xiao) I think most of the config should be in oai.config
use_azure: bool = True
use_azure_token_provider: bool = False
managed_identity_client_id: str | None = None
max_retry: int = 10
retry_wait_seconds: int = 1
dump_chat_cache: bool = False
Expand Down
6 changes: 5 additions & 1 deletion rdagent/oai/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def __init__( # noqa: C901, PLR0912, PLR0915
else:
self.use_azure = self.cfg.use_azure
self.use_azure_token_provider = self.cfg.use_azure_token_provider
self.managed_identity_client_id = self.cfg.managed_identity_client_id

self.chat_api_key = self.cfg.chat_openai_api_key if chat_api_key is None else chat_api_key
self.chat_model = self.cfg.chat_model if chat_model is None else chat_model
Expand All @@ -314,7 +315,10 @@ def __init__( # noqa: C901, PLR0912, PLR0915

if self.use_azure:
if self.use_azure_token_provider:
credential = DefaultAzureCredential()
dac_kwargs = {}
if self.managed_identity_client_id is not None:
dac_kwargs["managed_identity_client_id"] = self.managed_identity_client_id
credential = DefaultAzureCredential(**dac_kwargs)
token_provider = get_bearer_token_provider(
credential,
"https://cognitiveservices.azure.com/.default",
Expand Down

0 comments on commit 8c27cf9

Please sign in to comment.