Skip to content

Commit

Permalink
Add support for Azure authentication using Azure AD Tokens (#165)
Browse files Browse the repository at this point in the history
* support Azure AD token

* support Azure AD token

* support Azure AD token

* support Azure AD token

* support Azure AD token

* linting

* apply suggestion

* add copyrights

* fix client

* increase version

* fix mypy

* implement workaround to make azure_endpoint a positional argument

* linting

* Update README.md

---------

Co-authored-by: Alaeddine Abdessalem <[email protected]>
Co-authored-by: Philip May <[email protected]>
  • Loading branch information
3 people authored Jul 3, 2024
1 parent 0a59486 commit 7c19235
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 12 deletions.
1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ MIT License

Copyright (c) 2023-2024 Philip May
Copyright (c) 2023-2024 Philip May, Deutsche Telekom AG
Copyright (c) 2023-2024 Alaeddine Abdessalem, Deutsche Telekom AG

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ To install those module specific dependencies see
## Licensing

Copyright (c) 2023-2024 [Philip May](https://philipmay.org)\
Copyright (c) 2023-2024 [Philip May](https://philipmay.org), [Deutsche Telekom AG](https://www.telekom.de/)
Copyright (c) 2023-2024 [Philip May](https://philipmay.org), [Deutsche Telekom AG](https://www.telekom.de/)\
Copyright (c) 2023-2024 Alaeddine Abdessalem, [Deutsche Telekom AG](https://www.telekom.de/)

Licensed under the **MIT License** (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License by reviewing the file
Expand Down
58 changes: 48 additions & 10 deletions mltb2/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023-2024 Philip May
# Copyright (c) 2024 Philip May, Deutsche Telekom AG
# Copyright (c) 2024 Alaeddine Abdessalem, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

Expand Down Expand Up @@ -171,37 +172,40 @@ class OpenAiChat:
model: The OpenAI model name.
"""

api_key: str
model: str
client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False)
async_client: Union[AsyncOpenAI, AsyncAzureOpenAI] = field(init=False, repr=False)
api_key: Optional[str] = None

def __post_init__(self) -> None:
"""Do post init."""
self.client = OpenAI(api_key=self.api_key)
self.async_client = AsyncOpenAI(api_key=self.api_key)

@classmethod
def from_yaml(cls, yaml_file):
def from_yaml(cls, yaml_file, api_key: Optional[str] = None, **kwargs):
"""Construct this class from a yaml file.
If the ``api_key`` is not set in the yaml file,
it will be loaded from the environment variable ``OPENAI_API_KEY``.
Args:
yaml_file: The yaml file.
api_key: The OpenAI API key.
kwargs: extra kwargs to override parameters
Returns:
The constructed class.
"""
with open(yaml_file, "r") as file:
completion_kwargs = yaml.safe_load(file)

# load api_key from environment variable if it is not set in the yaml file
if "api_key" not in completion_kwargs:
api_key = os.getenv("OPENAI_API_KEY")
if api_key is not None:
completion_kwargs["api_key"] = api_key
# set api_key according to this priority:
# method parameter > yaml > environment variable
api_key = api_key or completion_kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
completion_kwargs["api_key"] = api_key

if kwargs:
completion_kwargs.update(kwargs)
return cls(**completion_kwargs)

def create_completions(
Expand Down Expand Up @@ -323,8 +327,16 @@ async def create_completions_async(
return result


# there is a limitation with python dataclasses when it comes to defining a subclass with positional arguments, while
# the parent class already defines keyword arguemnts (positional arguments cannot follow keyword arguments)
# workaroung is defined here: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class OpenAiAzureChat(OpenAiChat):
class _OpenAiAzureChatBase:
azure_endpoint: str


@dataclass
class OpenAiAzureChat(OpenAiChat, _OpenAiAzureChatBase):
"""Tool to interact with Azure OpenAI chat models.
This can also be constructed with :meth:`~OpenAiChat.from_yaml`.
Expand All @@ -341,18 +353,44 @@ class OpenAiAzureChat(OpenAiChat):
azure_endpoint: The Azure endpoint.
"""

api_version: str
azure_endpoint: str
api_version: Optional[str] = None
api_key: Optional[str] = None
azure_ad_token: Optional[str] = None

def __post_init__(self) -> None:
"""Do post init."""
self.client = AzureOpenAI(
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_ad_token=self.azure_ad_token,
)
self.async_client = AsyncAzureOpenAI(
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_ad_token=self.azure_ad_token,
)

@classmethod
def from_yaml(cls, yaml_file, api_key: Optional[str] = None, azure_ad_token: Optional[str] = None, **kwargs):
"""Construct this class from a yaml file.
If the ``api_key`` is not set in the yaml file,
it will be loaded from the environment variable ``OPENAI_API_KEY``.
Args:
yaml_file: The yaml file.
api_key: The OpenAI API key.
azure_ad_token: Azure AD token
kwargs: extra kwargs to override parameters
Returns:
The constructed class.
"""
with open(yaml_file, "r") as file:
completion_kwargs = yaml.safe_load(file)

# set azure_ad_token according to this priority:
# method parameter > yaml > environment variable
azure_ad_token = azure_ad_token or completion_kwargs.get("AZURE_AD_TOKEN") or os.getenv("AZURE_AD_TOKEN")
return super().from_yaml(yaml_file, api_key=api_key, azure_ad_token=azure_ad_token, **kwargs)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mltb2"
version = "1.0.1rc2"
version = "1.0.1rc3"
description = "Machine Learning Toolbox 2"
authors = ["PhilipMay <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 7c19235

Please sign in to comment.