Litellm dev 12 30 2024 p1 (BerriAI#7480)
* test(azure_openai_o1.py): initial commit with testing for azure openai o1 preview model

* fix(base_llm_unit_tests.py): handle azure o1 preview response format tests

Skip these tests, as o1 on Azure doesn't support tool calling yet

* fix: initial commit of azure o1 handler using openai caller

Simplifies calling and lets the fake-streaming logic already implemented for OpenAI work as-is

* feat(azure/o1_handler.py): fake o1 streaming for azure o1 models

azure does not currently support streaming for o1

* feat(o1_transformation.py): support overriding 'should_fake_stream' on azure/o1 via 'supports_native_streaming' param on model info

Lets users switch o1 to native streaming as soon as Azure enables it, without waiting for a litellm version bump (see the config sketch after the changed-files summary below)

* style(router.py): remove 'give feedback/get help' messaging when router is used

Prevents noisy messaging

Closes BerriAI#5942

* test: fix azure o1 test

* test: fix tests

* fix: fix test
krrishdholakia authored Dec 31, 2024
1 parent 60bdfb4 commit 347779b
Showing 17 changed files with 274 additions and 142 deletions.
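
The 'supports_native_streaming' toggle from the commit message is read through get_model_info(). A hedged sketch of how a user might flip it, assuming litellm.register_model() merges custom model-info keys into the map that get_model_info() consults; the deployment name and the extra keys here are illustrative, not part of this diff:

import litellm

# Hypothetical override: mark an Azure o1 deployment as natively streamable so
# AzureOpenAIO1Config.should_fake_stream() returns False and litellm stops
# simulating the stream. Assumes register_model() forwards this key to
# get_model_info(); flip it on once Azure actually ships native o1 streaming.
litellm.register_model(
    {
        "azure/o1-preview": {  # placeholder model name
            "litellm_provider": "azure",
            "mode": "chat",
            "supports_native_streaming": True,
        }
    }
)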
126 changes: 39 additions & 87 deletions litellm/llms/azure/chat/o1_handler.py
@@ -4,96 +4,48 @@
 Written separately to handle faking streaming for o1 models.
 """
 
-import asyncio
-from typing import Any, Callable, List, Optional, Union
+from typing import Optional, Union
 
-from httpx._config import Timeout
+import httpx
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
 
-from litellm.litellm_core_utils.litellm_logging import Logging
-from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
-from litellm.types.utils import ModelResponse
-from litellm.utils import CustomStreamWrapper
+from ...openai.openai import OpenAIChatCompletion
+from ..common_utils import get_azure_openai_client
-
-from ..azure import AzureChatCompletion
 
 
-class AzureOpenAIO1ChatCompletion(AzureChatCompletion):
-
-    async def mock_async_streaming(
-        self,
-        response: Any,
-        model: Optional[str],
-        logging_obj: Any,
-    ):
-        model_response = await response
-        completion_stream = MockResponseIterator(model_response=model_response)
-        streaming_response = CustomStreamWrapper(
-            completion_stream=completion_stream,
-            model=model,
-            custom_llm_provider="azure",
-            logging_obj=logging_obj,
-        )
-        return streaming_response
-
-    def completion(
+class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
+    def _get_openai_client(
         self,
-        model: str,
-        messages: List,
-        model_response: ModelResponse,
-        api_key: str,
-        api_base: str,
-        api_version: str,
-        api_type: str,
-        azure_ad_token: str,
-        dynamic_params: bool,
-        print_verbose: Callable[..., Any],
-        timeout: Union[float, Timeout],
-        logging_obj: Logging,
-        optional_params,
-        litellm_params,
-        logger_fn,
-        acompletion: bool = False,
-        headers: Optional[dict] = None,
-        client=None,
-    ):
-        stream: Optional[bool] = optional_params.pop("stream", False)
-        stream_options: Optional[dict] = optional_params.pop("stream_options", None)
-        response = super().completion(
-            model,
-            messages,
-            model_response,
-            api_key,
-            api_base,
-            api_version,
-            api_type,
-            azure_ad_token,
-            dynamic_params,
-            print_verbose,
-            timeout,
-            logging_obj,
-            optional_params,
-            litellm_params,
-            logger_fn,
-            acompletion,
-            headers,
-            client,
+        is_async: bool,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        max_retries: Optional[int] = 2,
+        organization: Optional[str] = None,
+        client: Optional[
+            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
+        ] = None,
+    ) -> Optional[
+        Union[
+            OpenAI,
+            AsyncOpenAI,
+            AzureOpenAI,
+            AsyncAzureOpenAI,
+        ]
+    ]:
+
+        # Override to use Azure-specific client initialization
+        if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
+            client = None
+
+        return get_azure_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            api_version=api_version,
+            client=client,
+            _is_async=is_async,
         )
-
-        if stream is True:
-            if asyncio.iscoroutine(response):
-                return self.mock_async_streaming(
-                    response=response, model=model, logging_obj=logging_obj  # type: ignore
-                )
-
-            completion_stream = MockResponseIterator(model_response=response)
-            streaming_response = CustomStreamWrapper(
-                completion_stream=completion_stream,
-                model=model,
-                custom_llm_provider="openai",
-                logging_obj=logging_obj,
-                stream_options=stream_options,
-            )
-
-            return streaming_response
-        else:
-            return response
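
What the rewritten handler means for callers, as a hedged end-to-end sketch: with stream=True against an Azure o1 deployment, litellm requests the full completion once and replays it as chunks (fake streaming), so existing streaming code paths keep working. Deployment name, endpoint, key, and API version below are placeholders:

import litellm

# stream=True is "faked" for Azure o1: the complete response is fetched, then
# wrapped in a chunk iterator, so the caller's streaming loop is unchanged.
response = litellm.completion(
    model="azure/o1-preview",  # placeholder deployment name
    messages=[{"role": "user", "content": "Summarize this change in one line."}],
    api_base="https://my-resource.openai.azure.com",  # placeholder endpoint
    api_key="my-azure-api-key",  # placeholder credential
    api_version="2024-08-01-preview",  # placeholder API version
    stream=True,
)
for chunk in response:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="")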
31 changes: 31 additions & 0 deletions litellm/llms/azure/chat/o1_transformation.py
@@ -12,10 +12,41 @@
 - Temperature => drop param (if user opts in to dropping param)
 """
 
+from typing import Optional
+
+from litellm import verbose_logger
+from litellm.utils import get_model_info
+
 from ...openai.chat.o1_transformation import OpenAIO1Config
 
 
 class AzureOpenAIO1Config(OpenAIO1Config):
+    def should_fake_stream(
+        self,
+        model: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
+    ) -> bool:
+        """
+        Currently no Azure OpenAI models support native streaming.
+        """
+        if stream is not True:
+            return False
+
+        if model is not None:
+            try:
+                model_info = get_model_info(
+                    model=model, custom_llm_provider=custom_llm_provider
+                )
+                if model_info.get("supports_native_streaming") is True:
+                    return False
+            except Exception as e:
+                verbose_logger.debug(
+                    f"Error getting model info in AzureOpenAIO1Config: {e}"
+                )
+
+        return True
+
     def is_o1_model(self, model: str) -> bool:
         o1_models = ["o1-mini", "o1-preview"]
         for m in o1_models:
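
A quick check of the decision logic above, using the module path and class name from this diff; the second call assumes no model-info entry marks o1-preview as natively streamable on Azure (true at the time of this commit), so the config falls back to faking the stream:

from litellm.llms.azure.chat.o1_transformation import AzureOpenAIO1Config

config = AzureOpenAIO1Config()

# No stream requested -> nothing to fake.
assert config.should_fake_stream(model="o1-preview", stream=False) is False

# Stream requested, no supports_native_streaming override -> fake it.
assert (
    config.should_fake_stream(
        model="o1-preview", stream=True, custom_llm_provider="azure"
    )
    is True
)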
35 changes: 35 additions & 0 deletions litellm/llms/azure/common_utils.py
@@ -1,7 +1,9 @@
 from typing import Callable, Optional, Union
 
 import httpx
+from openai import AsyncAzureOpenAI, AzureOpenAI
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
@@ -25,6 +27,39 @@ def __init__(
         )
 
 
+def get_azure_openai_client(
+    api_key: Optional[str],
+    api_base: Optional[str],
+    timeout: Union[float, httpx.Timeout],
+    max_retries: Optional[int],
+    api_version: Optional[str] = None,
+    organization: Optional[str] = None,
+    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+    _is_async: bool = False,
+) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
+    received_args = locals()
+    openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
+    if client is None:
+        data = {}
+        for k, v in received_args.items():
+            if k == "self" or k == "client" or k == "_is_async":
+                pass
+            elif k == "api_base" and v is not None:
+                data["azure_endpoint"] = v
+            elif v is not None:
+                data[k] = v
+        if "api_version" not in data:
+            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
+        if _is_async is True:
+            openai_client = AsyncAzureOpenAI(**data)
+        else:
+            openai_client = AzureOpenAI(**data)  # type: ignore
+    else:
+        openai_client = client
+
+    return openai_client
+
+
 def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
     if "x-ratelimit-limit-requests" in headers:
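
A small usage sketch of the shared helper added above, covering the sync and async variants; the endpoint and key are placeholders, and api_version falls back to litellm.AZURE_DEFAULT_API_VERSION when omitted:

from litellm.llms.azure.common_utils import get_azure_openai_client

# Sync client: api_base is remapped to the SDK's azure_endpoint kwarg internally.
client = get_azure_openai_client(
    api_key="my-azure-api-key",  # placeholder credential
    api_base="https://my-resource.openai.azure.com",  # placeholder endpoint
    timeout=600.0,
    max_retries=2,
)

# Async client: same arguments, but _is_async=True returns an AsyncAzureOpenAI.
async_client = get_azure_openai_client(
    api_key="my-azure-api-key",
    api_base="https://my-resource.openai.azure.com",
    timeout=600.0,
    max_retries=2,
    _is_async=True,
)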
34 changes: 1 addition & 33 deletions litellm/llms/azure/files/handler.py
@@ -4,43 +4,11 @@
 from openai import AsyncAzureOpenAI, AzureOpenAI
 from openai.types.file_deleted import FileDeleted
 
-import litellm
-from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.types.llms.openai import *
 
 
-def get_azure_openai_client(
-    api_key: Optional[str],
-    api_base: Optional[str],
-    timeout: Union[float, httpx.Timeout],
-    max_retries: Optional[int],
-    api_version: Optional[str] = None,
-    organization: Optional[str] = None,
-    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-    _is_async: bool = False,
-) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
-    received_args = locals()
-    openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
-    if client is None:
-        data = {}
-        for k, v in received_args.items():
-            if k == "self" or k == "client" or k == "_is_async":
-                pass
-            elif k == "api_base" and v is not None:
-                data["azure_endpoint"] = v
-            elif v is not None:
-                data[k] = v
-        if "api_version" not in data:
-            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
-        if _is_async is True:
-            openai_client = AsyncAzureOpenAI(**data)
-        else:
-            openai_client = AzureOpenAI(**data)  # type: ignore
-    else:
-        openai_client = client
-
-    return openai_client
+from ..common_utils import get_azure_openai_client
 
 
 class AzureOpenAIFilesAPI(BaseLLM):