Skip to content

Commit

Permalink
update healthck performance, improve client selection flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
trisongz committed Feb 28, 2024
1 parent b3a3150 commit cbeb65c
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 53 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# Changelogs

#### v0.0.52 (2024-02-28)

- Added support for the following parameters in `model_configurations` in `OpenAIManager`:

- `ping_timeout` - allows for custom timeouts for each client.

- `included_models` - allows for more flexible setting of models in Azure.

- `weight` - allows for weighted selection of clients.

- Improved Healthcheck behavior to cache if successful for a period of time, and always recheck if not.

- Added `dimension` parameter for `embedding` models.

#### v0.0.51rc (2024-02-07)

- Modification of `async_openai.types.context.ModelContextHandler` to a proxied object singleton.
Expand Down
39 changes: 24 additions & 15 deletions async_openai/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import aiohttpx
import contextlib
from typing import Optional, Callable, Dict, Union, List

from lazyops.utils.helpers import timed_cache
from async_openai.schemas import *
from async_openai.types.options import ApiType
from async_openai.utils.logs import logger
Expand Down Expand Up @@ -413,31 +413,40 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_value, traceback):
await self.async_close()


def ping(self, timeout: Optional[float] = 1.0) -> bool:
@timed_cache(secs = 120, cache_if_result = True)
def ping(self, timeout: Optional[float] = 1.0, base_url: Optional[str] = None) -> bool:
"""
Pings the API Endpoint to check if it's alive.
"""
try:
# with contextlib.suppress(Exception):
response = self.client.get('/', timeout = timeout)
data = response.json()
# we should expect a 404 with a json response
# if self.debug_enabled: logger.info(f"API Ping: {data}\n{response.headers}")
if data.get('error'): return True
response = self.client.get(base_url or '/', timeout = timeout)
try:
data = response.json()
# we should expect a 404 with a json response
# if self.debug_enabled: logger.info(f"API Ping: {data}\n{response.headers}")
if data.get('error'): return True
except Exception as e:
logger.error(f"[{self.name} - {response.status_code}] API Ping Failed: {response.text[:500]}")
except Exception as e:
logger.error(f"API Ping Failed: {e}")
logger.error(f"[{self.name}] API Ping Failed: {e}")
return False

async def aping(self, timeout: Optional[float] = 1.0) -> bool:
@timed_cache(secs = 120, cache_if_result = True)
async def aping(self, timeout: Optional[float] = 1.0, base_url: Optional[str] = None) -> bool:
"""
Pings the API Endpoint to check if it's alive.
"""
with contextlib.suppress(Exception):
response = await self.client.async_get('/', timeout = timeout)
data = response.json()
# we should expect a 404 with a json response
if data.get('error'): return True
try:
response = await self.client.async_get(base_url or '/', timeout = timeout)
try:
data = response.json()
# we should expect a 404 with a json response
if data.get('error'): return True
except Exception as e:
logger.error(f"[{self.name} - {response.status_code}] API Ping Failed: {response.text[:500]}")
except Exception as e:
logger.error(f"[{self.name}] API Ping Failed: {e}")
return False


Expand Down
4 changes: 3 additions & 1 deletion async_openai/external_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,13 @@ async def aping(self, timeout: Optional[float] = 1.0) -> bool:
"""
Pings the API Endpoint to check if it's alive.
"""
with contextlib.suppress(Exception):
try:
response = await self.client.async_get('/', timeout = timeout)
data = response.json()
# we should expect a 404 with a json response
if data.get('error'): return True
except Exception as e:
logger.error(f"[{self.name}] API Ping Failed: {e}")
return False


27 changes: 20 additions & 7 deletions async_openai/loadbalancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def get_api_client(self, client_name: Optional[str] = None, require_azure: Optio
client_name = 'default'
if client_name and client_name not in self.clients:
self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs)

if not client_name and require_azure:
while not self.api.is_azure:
self.increase_rotate_index()
Expand All @@ -212,33 +211,47 @@ def get_api_client_from_list(self, client_names: List[str], require_azure: Optio
Initializes a new OpenAI client or Returns an existing one from a list of client names.
"""
if not self.healthcheck:
name = random.choice(client_names)
name = self.manager.select_client_name_from_weights(client_names) if self.manager.has_client_weights else random.choice(client_names)
return self.get_api_client(client_name = name, require_azure = require_azure, **kwargs)
available = []
for client_name in client_names:
if client_name not in self.clients:
self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs)
if require_azure and not self.clients[client_name].is_azure:
continue
if not self.clients[client_name].ping():
if not self.clients[client_name].ping(**self.manager.get_client_ping_params(client_name)):
continue
return self.clients[client_name]
if not self.manager.has_client_weights:
return self.clients[client_name]
available.append(client_name)
# return self.clients[client_name]
if available:
name = self.manager.select_client_name_from_weights(available)
return self.clients[name]
raise ValueError(f'No healthy client found from: {client_names}')

async def aget_api_client_from_list(self, client_names: List[str], require_azure: Optional[bool] = None, **kwargs) -> 'OpenAIClient':
"""
Initializes a new OpenAI client or Returns an existing one from a list of client names.
"""
if not self.healthcheck:
name = random.choice(client_names)
name = self.manager.select_client_name_from_weights(client_names) if self.manager.has_client_weights else random.choice(client_names)
return self.get_api_client(client_name = name, require_azure = require_azure, **kwargs)
available = []
for client_name in client_names:
if client_name not in self.clients:
self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs)
if require_azure and not self.clients[client_name].is_azure:
continue
if not await self.clients[client_name].aping():
if not await self.clients[client_name].aping(**self.manager.get_client_ping_params(client_name)):
continue
return self.clients[client_name]
if not self.manager.has_client_weights:
return self.clients[client_name]
available.append(client_name)

if available:
name = self.manager.select_client_name_from_weights(available)
return self.clients[name]
raise ValueError(f'No healthy client found from: {client_names}')

def __getitem__(self, key: Union[str, int]) -> 'OpenAIClient':
Expand Down
103 changes: 93 additions & 10 deletions async_openai/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from async_openai.types.context import ModelContextHandler
from async_openai.utils.config import get_settings, OpenAISettings
from async_openai.utils.external_config import ExternalProviderSettings
from async_openai.utils.helpers import weighted_choice
from async_openai.types.functions import FunctionManager, OpenAIFunctions
from async_openai.utils.logs import logger

Expand All @@ -37,7 +38,17 @@
'gpt-3.5-turbo-0613': 'gpt-35-turbo-0613',
'gpt-3.5-turbo-1106': 'gpt-35-turbo-1106',
}

DefaultAvailableModels = [
'gpt-4',
'gpt-4-32k',
'gpt-4-turbo',
'gpt-4-1106-preview',
'gpt-4-0125-preview',
'text-embedding-ada-2',
'text-embedding-3-small',
'text-embedding-3-large',
] + list(DefaultModelMapping.values())

class OpenAIManager(abc.ABC):
name: Optional[str] = "openai"
on_error: Optional[Callable] = None
Expand All @@ -61,7 +72,11 @@ def __init__(self, **kwargs):
"""
Initializes the OpenAI API Client
"""
self.client_weights: Optional[Dict[str, float]] = {}
self.client_ping_timeouts: Optional[Dict[str, float]] = {}
self.client_model_exclusions: Optional[Dict[str, Dict[str, Union[bool, List[str]]]]] = {}
self.client_base_urls: Optional[Dict[str, str]] = {}

self.no_proxy_client_names: Optional[List[str]] = []
self.client_callbacks: Optional[List[Callable]] = []
self.functions: FunctionManager = OpenAIFunctions
Expand All @@ -74,6 +89,13 @@ def __init__(self, **kwargs):
self.external_client_default: Optional[str] = None
# self._external_clients:

@property
def has_client_weights(self) -> bool:
"""
Returns if the client has weights
"""
return bool(self.client_weights)

def add_callback(self, callback: Callable):
"""
Adds a callback to the client
Expand Down Expand Up @@ -271,16 +293,24 @@ def get_api_client_from_list(
if not client_names: return self.apis.get_api_client(**kwargs)
return self.apis.get_api_client_from_list(client_names = client_names, **kwargs)
if not client_names: return self.get_api_client(**kwargs)

if not self.auto_healthcheck:
name = random.choice(client_names)
name = self.select_client_name_from_weights(client_names) if self.has_client_weights else random.choice(client_names)
return self.get_api_client(client_name = name, **kwargs)

available = []
for client_name in client_names:
if client_name not in self._clients:
self._clients[client_name] = self.init_api_client(client_name = client_name, **kwargs)
if not self._clients[client_name].ping():
if not self._clients[client_name].ping(**self.get_client_ping_params(client_name)):
continue
return self._clients[client_name]
if not self.has_client_weights:
return self._clients[client_name]
available.append(client_name)

if available:
name = self.select_client_name_from_weights(available)
return self._clients[name]
raise ValueError(f'No healthy client found from: {client_names}')

async def aget_api_client_from_list(
Expand All @@ -296,15 +326,22 @@ async def aget_api_client_from_list(
return await self.apis.aget_api_client_from_list(client_name = client_name, **kwargs)
if not client_names: return self.get_api_client(**kwargs)
if not self.auto_healthcheck:
name = random.choice(client_names)
name = self.select_client_name_from_weights(client_names) if self.has_client_weights else random.choice(client_names)
return self.get_api_client(client_name = name, **kwargs)

available = []
for client_name in client_names:
if client_name not in self._clients:
self._clients[client_name] = self.init_api_client(client_name = client_name, **kwargs)
if not await self._clients[client_name].aping():
if not await self._clients[client_name].aping(**self.get_client_ping_params(client_name)):
continue
return self._clients[client_name]
if not self.has_client_weights:
return self._clients[client_name]
available.append(client_name)

if available:
name = self.select_client_name_from_weights(available)
return self._clients[name]
raise ValueError(f'No healthy client found from: {client_names}')


Expand Down Expand Up @@ -535,6 +572,26 @@ def _ensure_api(self):
"""
if self._api is None: self.configure_internal_apis()

def select_client_name_from_weights(self, names: List[str]) -> str:
"""
Returns the client weights
"""
return weighted_choice([(name, self.client_weights.get(name, 1.0)) for name in names])

# def get_client_ping_timeout(self, name: str) -> Optional[float]:
# """
# Returns the client timeout
# """
# return self.client_ping_timeouts.get(name, 1.0)

def get_client_ping_params(self, name: str) -> Dict[str, Union[float, str]]:
"""
Returns the client ping parameters
"""
return {
'timeout': self.client_ping_timeouts.get(name, 1.0),
'base_url': self.client_base_urls.get(name),
}

"""
API Routes
Expand Down Expand Up @@ -640,32 +697,47 @@ def register_default_endpoints(self):
self.init_api_client('azure', is_azure = True)


def register_client_endpoints(self):
def register_client_endpoints(self): # sourcery skip: low-code-quality
"""
Register the Client Endpoints
"""
client_configs = copy.deepcopy(self.settings.client_configurations)
seen_models = set(DefaultAvailableModels)
has_weights = any(c.get('weight') for c in client_configs.values())
for name, config in client_configs.items():
is_enabled = config.pop('enabled', False)
if not is_enabled: continue
is_azure = 'azure' in name or 'az' in name or config.get('is_azure', False)
is_default = config.pop('default', False)
proxy_disabled = config.pop('proxy_disabled', False)
source_endpoint = config.get('api_base')
client_weight = config.pop('weight', 1.0) if has_weights else None
client_ping_timeout = config.pop('ping_timeout', None)

if self.debug_enabled is not None: config['debug_enabled'] = self.debug_enabled
if excluded_models := config.pop('excluded_models', None):
self.client_model_exclusions[name] = {
'models': excluded_models, 'is_azure': is_azure,
}
seen_models.update(excluded_models)
else:
self.client_model_exclusions[name] = {
'models': None, 'is_azure': is_azure,
}

if included_models := config.pop('included_models', None):
self.client_model_exclusions[name]['included_models'] = included_models

if client_weight: self.client_weights[name] = client_weight
if client_ping_timeout is not None: self.client_ping_timeouts[name] = client_ping_timeout

if (self.settings.proxy.enabled and not proxy_disabled) and config.get('api_base'):
# Initialize a non-proxy version of the client
config['api_base'] = source_endpoint
non_proxy_name = f'{name}_noproxy'
self.client_base_urls[name] = source_endpoint
if client_weight: self.client_weights[non_proxy_name] = client_weight
if client_ping_timeout is not None: self.client_ping_timeouts[non_proxy_name] = client_ping_timeout
self.client_model_exclusions[non_proxy_name] = self.client_model_exclusions[name].copy()
self.no_proxy_client_names.append(non_proxy_name)
self.init_api_client(non_proxy_name, is_azure = is_azure, set_as_default = False, **config)
Expand All @@ -675,8 +747,19 @@ def register_client_endpoints(self):
)
config['api_base'] = self.settings.proxy.endpoint
c = self.init_api_client(name, is_azure = is_azure, set_as_default = is_default, **config)
logger.info(f'Registered: `|g|{c.name}|e|` @ `{source_endpoint or c.base_url}` (Azure: {c.is_azure})', colored = True)

msg = f'Registered: `|g|{c.name}|e|` @ `{source_endpoint or c.base_url}` (Azure: {c.is_azure}'
if has_weights: msg += f', Weight: {client_weight}'
msg += ')'
logger.info(msg, colored = True)

# Set the models for inclusion
for name in self.client_model_exclusions:
if not self.client_model_exclusions[name].get('included_models'): continue
included_models = self.client_model_exclusions[name].pop('included_models')
self.client_model_exclusions[name]['models'] = [m for m in seen_models if m not in included_models]
# if self.settings.debug_enabled:
# logger.info(f'|g|{name}|e| Included: {included_models}, Excluded: {self.client_model_exclusions[name]["models"]}', colored = True)


def select_client_names(
self,
Expand Down
Loading

0 comments on commit cbeb65c

Please sign in to comment.