Commit 9b57c83

minor fixes

trisongz committed Feb 23, 2024
1 parent ce50aaf
Showing 6 changed files with 24 additions and 12 deletions.
2 changes: 2 additions & 0 deletions async_openai/external_client.py
@@ -51,6 +51,7 @@ def __init__(

         self.is_proxied = is_proxied if is_proxied is not None else \
             (self.provider.config.has_proxy and '_noproxy' not in self.name)
+        # logger.info(f"External Provider Configured: {self.name} [Proxied: {self.is_proxied}]")
 
         self.settings: Optional[OpenAISettings] = kwargs.pop('settings', get_settings())
         self.client_callbacks: List[Callable] = []
@@ -151,6 +152,7 @@ def configure_client(self, **kwargs):
             },
             **extra_kwargs,
         )
+        # logger.info(f"External Configured: {self._client.base_url} [{self.name}]")
 
     def configure_routes(self, **kwargs):
         """
21 changes: 13 additions & 8 deletions async_openai/manager.py
@@ -747,6 +747,7 @@ def select_external_client_names(
         else: client_names = self.external_client_names
         if noproxy_required: client_names = [c for c in client_names if 'noproxy' in c]
         else: client_names = [c for c in client_names if 'noproxy' not in c]
+        # logger.info(f'External Clients: {client_names}')
         if excluded_clients: client_names = [c for c in client_names if c not in excluded_clients]
         return list(set(client_names))
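For context, the proxy filter above behaves as in this quick standalone check (client names hypothetical):

    names = ['together', 'together_noproxy', 'mistral']
    # noproxy_required=True keeps only the '_noproxy' variants ...
    assert [c for c in names if 'noproxy' in c] == ['together_noproxy']
    # ... while the default path keeps only the proxied clients
    assert [c for c in names if 'noproxy' not in c] == ['together', 'mistral']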

@@ -1302,6 +1303,8 @@ def create_embeddings(
         model: Optional[str] = None,
         auto_retry: Optional[bool] = True,
         strip_newlines: Optional[bool] = False,
+        headers: Optional[Dict[str, str]] = None,
+        noproxy_required: Optional[bool] = False,
         **kwargs,
     ) -> List[List[float]]:
         """
@@ -1322,20 +1325,20 @@ def create_embeddings(
             **kwargs
         )
         if strip_newlines: inputs = [i.replace('\n', ' ').strip() for i in inputs]
-        client = self.get_client(model = model, **kwargs)
+        client = self.get_client(model = model, noproxy_required = noproxy_required, **kwargs)
         if not client.is_azure:
-            response = client.embeddings.create(input = inputs, model = model, auto_retry = auto_retry, **kwargs)
+            response = client.embeddings.create(input = inputs, model = model, auto_retry = auto_retry, headers = headers, **kwargs)
             return response.embeddings
 
         embeddings = []
         # We need to split into batches of 5 for Azure
         # Azure has a limit of 5 inputs per request
         batches = split_into_batches(inputs, 5)
         for batch in batches:
-            response = client.embeddings.create(input = batch, model = model, auto_retry = auto_retry, **kwargs)
+            response = client.embeddings.create(input = batch, model = model, auto_retry = auto_retry, headers = headers, **kwargs)
             embeddings.extend(response.embeddings)
             # Shuffle the clients to load balance
-            client = self.get_client(model = model, azure_required = True, **kwargs)
+            client = self.get_client(model = model, azure_required = True, noproxy_required = noproxy_required, **kwargs)
         return embeddings


@@ -1345,6 +1348,8 @@ async def async_create_embeddings(
         model: Optional[str] = None,
         auto_retry: Optional[bool] = True,
         strip_newlines: Optional[bool] = False,
+        headers: Optional[Dict[str, str]] = None,
+        noproxy_required: Optional[bool] = False,
         **kwargs,
     ) -> List[List[float]]:
         """
@@ -1365,20 +1370,20 @@ async def async_create_embeddings(
             **kwargs
         )
         if strip_newlines: inputs = [i.replace('\n', ' ').strip() for i in inputs]
-        client = self.get_client(model = model, **kwargs)
+        client = self.get_client(model = model, noproxy_required = noproxy_required, **kwargs)
         if not client.is_azure:
-            response = await client.embeddings.async_create(input = inputs, model = model, auto_retry = auto_retry, **kwargs)
+            response = await client.embeddings.async_create(input = inputs, model = model, auto_retry = auto_retry, headers = headers, **kwargs)
             return response.embeddings
 
         embeddings = []
         # We need to split into batches of 5 for Azure
         # Azure has a limit of 5 inputs per request
         batches = split_into_batches(inputs, 5)
         for batch in batches:
-            response = await client.embeddings.async_create(input = batch, model = model, auto_retry = auto_retry, **kwargs)
+            response = await client.embeddings.async_create(input = batch, model = model, auto_retry = auto_retry, headers = headers, **kwargs)
             embeddings.extend(response.embeddings)
             # Shuffle the clients to load balance
-            client = self.get_client(model = model, azure_required = True, **kwargs)
+            client = self.get_client(model = model, azure_required = True, noproxy_required = noproxy_required, **kwargs)
         return embeddings
 
     acreate_embeddings = async_create_embeddings
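The batching helper itself is not part of this diff; a minimal sketch of what split_into_batches is assumed to do (name taken from the calls above, implementation illustrative):

    from typing import List, TypeVar

    T = TypeVar('T')

    def split_into_batches(items: List[T], batch_size: int) -> List[List[T]]:
        # Chunk inputs so each Azure request stays within its
        # 5-inputs-per-request embedding limit.
        return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    assert split_into_batches([1, 2, 3, 4, 5, 6, 7], 5) == [[1, 2, 3, 4, 5], [6, 7]]

With the two new parameters threaded through, a caller can pin a request to a non-proxied client and forward per-request headers in one call, roughly create_embeddings(..., noproxy_required = True, headers = {'Helicone-Cache-Enabled': 'true'}).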
6 changes: 4 additions & 2 deletions async_openai/types/functions.py
@@ -410,14 +410,15 @@ async def arun_chat_function(
         functions = functions or self.functions
         try:
             if headers: chat.client.headers.update(headers)
+            function_call = "auto" if function_name and function_name == 'auto' else {'name': function_name or self.function_name or self.name}
             return await chat.async_create(
                 model = model,
                 messages = messages,
                 functions = functions,
                 headers = headers,
                 auto_retry = True,
                 auto_retry_limit = 2,
-                function_call = {'name': function_name or self.name},
+                function_call = function_call,
                 header_cache_keys = ['Helicone-Cache-Enabled'],
                 **kwargs,
             )
@@ -480,13 +481,14 @@ def run_chat_function(
             headers.update(property_meta)
         if headers: chat.client.headers.update(headers)
         functions = functions or self.functions
+        function_call = "auto" if function_name and function_name == 'auto' else {'name': function_name or self.function_name or self.name}
         return chat.create(
             messages = messages,
             functions = functions,
             headers = headers,
             auto_retry = True,
             auto_retry_limit = self.retry_limit,
-            function_call = {'name': function_name or self.name},
+            function_call = function_call,
             header_cache_keys=['Helicone-Cache-Enabled'],
             **kwargs,
         )
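Both hunks carry the same behavior change: passing function_name='auto' now forwards the literal string "auto" (which, in the chat-completions convention, lets the model decide whether to call a function), while any other value still forces a named function. A small sketch of the dispatch (helper and names hypothetical):

    def resolve_function_call(function_name=None, default_name='my_function'):
        # 'auto' defers the decision to the model; anything else pins
        # the call to a specific function, falling back to the default.
        if function_name and function_name == 'auto':
            return "auto"
        return {'name': function_name or default_name}

    assert resolve_function_call('auto') == "auto"
    assert resolve_function_call() == {'name': 'my_function'}
    assert resolve_function_call('summarize') == {'name': 'summarize'}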
2 changes: 2 additions & 0 deletions async_openai/utils/embedding.py
@@ -27,6 +27,8 @@ def _initialize_distance_dict(*args, **kwargs) -> Dict[str, Callable[..., float]]:
"""
return {
"cosine": spatial.distance.cosine,
"euclidean": spatial.distance.euclidean,
"inner_product": lambda x, y: -np.dot(x, y),
"L1": spatial.distance.cityblock,
"L2": spatial.distance.euclidean,
"Linf": spatial.distance.chebyshev,
2 changes: 1 addition & 1 deletion async_openai/version.py
@@ -1 +1 @@
-VERSION = '0.0.51rc1'
+VERSION = '0.0.51rc2'
3 changes: 2 additions & 1 deletion tests/external_provider.py
@@ -1,6 +1,7 @@
 import os
 
-os.environ['TOGETHER_API_KEY'] = 'test123'
+# os.environ['TOGETHER_API_KEY'] = 'test123'
+os.environ['TOGETHER_API_KEYS'] = '[test1253, test4565]'
 
 from async_openai.utils.external_config import ExternalProviderSettings
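The test now sets a plural TOGETHER_API_KEYS variable carrying a bracketed list, presumably so the external provider settings can rotate across multiple keys. One plausible parse of that format (helper hypothetical, format inferred from the test value):

    def parse_api_keys(raw: str) -> list:
        # '[test1253, test4565]' -> ['test1253', 'test4565']
        return [k.strip() for k in raw.strip('[]').split(',') if k.strip()]

    assert parse_api_keys('[test1253, test4565]') == ['test1253', 'test4565']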
