From 60194818e0408d29b74b5c94a4e40aaface5f82e Mon Sep 17 00:00:00 2001 From: Logan Date: Wed, 18 Dec 2024 21:40:01 -0600 Subject: [PATCH] improve async search client handling (#17319) --- .../vector_stores/azureaisearch/base.py | 167 +++++++++++------- .../pyproject.toml | 2 +- 2 files changed, 104 insertions(+), 65 deletions(-) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py index 0601a408fc41c..4bac8671043b1 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/llama_index/vector_stores/azureaisearch/base.py @@ -4,7 +4,7 @@ import json import logging from enum import auto -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from azure.search.documents import SearchClient from azure.search.documents.aio import SearchClient as AsyncSearchClient @@ -523,12 +523,17 @@ async def _avalidate_index(self, index_name: Optional[str]) -> None: def __init__( self, - search_or_index_client: Any, + search_or_index_client: Union[ + SearchClient, SearchIndexClient, AsyncSearchClient, AsyncSearchIndexClient + ], id_field_key: str, chunk_field_key: str, embedding_field_key: str, metadata_string_field_key: str, doc_id_field_key: str, + async_search_or_index_client: Optional[ + Union[AsyncSearchClient, AsyncSearchIndexClient] + ] = None, filterable_metadata_field_keys: Optional[ Union[ List[str], @@ -617,13 +622,6 @@ def __init__( self._user_agent = ( f"{base_user_agent} {user_agent}" if user_agent else base_user_agent ) - - self._index_client: SearchIndexClient = cast(SearchIndexClient, None) - self._async_index_client: AsyncSearchIndexClient = cast( - AsyncSearchIndexClient, None - ) - self._search_client: SearchClient = cast(SearchClient, None) - self._async_search_client: AsyncSearchClient = cast(AsyncSearchClient, None) self._embedding_dimensionality = embedding_dimensionality self._index_name = index_name @@ -639,11 +637,22 @@ def __init__( self._language_analyzer = language_analyzer self._compression_type = compression_type.lower() - # Validate search_or_index_client + # Initialize clients to None + self._index_client = None + self._async_index_client = None + self._search_client = None + self._async_search_client = None + + if search_or_index_client and async_search_or_index_client is None: + logger.warning( + "async_search_or_index_client is None. Depending on the client type passed " + "in, sync or async functions may not work." + ) + + # Validate sync search_or_index_client if search_or_index_client is not None: if isinstance(search_or_index_client, SearchIndexClient): - # If SearchIndexClient is supplied so must index_name - self._index_client = cast(SearchIndexClient, search_or_index_client) + self._index_client = search_or_index_client self._index_client._client._config.user_agent_policy.add_user_agent( self._user_agent ) @@ -660,18 +669,32 @@ def __init__( self._user_agent ) - elif isinstance(search_or_index_client, AsyncSearchIndexClient): - # If SearchIndexClient is supplied so must index_name - self._async_index_client = cast( - AsyncSearchIndexClient, search_or_index_client + elif isinstance(search_or_index_client, SearchClient): + self._search_client = search_or_index_client + self._search_client._client._config.user_agent_policy.add_user_agent( + self._user_agent ) + # Validate index_name + if index_name: + raise ValueError( + "index_name cannot be supplied if search_or_index_client " + "is of type azure.search.documents.SearchClient" + ) + + # Validate async search_or_index_client -- if not provided, assume the search_or_index_client could be async + async_search_or_index_client = ( + async_search_or_index_client or search_or_index_client + ) + if async_search_or_index_client is not None: + if isinstance(async_search_or_index_client, AsyncSearchIndexClient): + self._async_index_client = async_search_or_index_client self._async_index_client._client._config.user_agent_policy.add_user_agent( self._user_agent ) if not index_name: raise ValueError( - "index_name must be supplied if search_or_index_client is of " + "index_name must be supplied if async_search_or_index_client is of " "type azure.search.documents.aio.SearchIndexClient" ) @@ -682,22 +705,8 @@ def __init__( self._user_agent ) - elif isinstance(search_or_index_client, SearchClient): - self._search_client = cast(SearchClient, search_or_index_client) - self._search_client._client._config.user_agent_policy.add_user_agent( - self._user_agent - ) - # Validate index_name - if index_name: - raise ValueError( - "index_name cannot be supplied if search_or_index_client " - "is of type azure.search.documents.SearchClient" - ) - - elif isinstance(search_or_index_client, AsyncSearchClient): - self._async_search_client = cast( - AsyncSearchClient, search_or_index_client - ) + elif isinstance(async_search_or_index_client, AsyncSearchClient): + self._async_search_client = async_search_or_index_client self._async_search_client._client._config.user_agent_policy.add_user_agent( self._user_agent ) @@ -705,35 +714,31 @@ def __init__( # Validate index_name if index_name: raise ValueError( - "index_name cannot be supplied if search_or_index_client " - "is of type azure.search.documents.SearchClient" + "index_name cannot be supplied if async_search_or_index_client " + "is of type azure.search.documents.aio.SearchClient" ) - if isinstance(search_or_index_client, AsyncSearchIndexClient): - if not self._async_index_client and not self._async_search_client: - raise ValueError( - "search_or_index_client must be of type " - "azure.search.documents.SearchIndexClient or " - "azure.search.documents.SearchClient" - ) - - if isinstance(search_or_index_client, SearchIndexClient): - if not self._index_client and not self._search_client: - raise ValueError( - "search_or_index_client must be of type " - "azure.search.documents.SearchIndexClient or " - "azure.search.documents.SearchClient" - ) - else: - raise ValueError("search_or_index_client not specified") + # Validate that at least one client was provided + if not any( + [ + self._search_client, + self._async_search_client, + self._index_client, + self._async_index_client, + ] + ): + raise ValueError( + "Either search_or_index_client or async_search_or_index_client must be provided" + ) + # Validate index management requirements if index_management == IndexManagement.CREATE_IF_NOT_EXISTS and not ( self._index_client or self._async_index_client ): raise ValueError( "index_management has value of IndexManagement.CREATE_IF_NOT_EXISTS " - "but search_or_index_client is not of type " - "azure.search.documents.SearchIndexClient or azure.search.documents.aio.SearchIndexClient " + "but neither search_or_index_client nor async_search_or_index_client is of type " + "azure.search.documents.SearchIndexClient or azure.search.documents.aio.SearchIndexClient" ) self._index_management = index_management @@ -1161,20 +1166,36 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul odata_filter = self._create_odata_filter(query.filters) azure_query_result_search: AzureQueryResultSearchBase = ( AzureQueryResultSearchDefault( - query, self._field_mapping, odata_filter, self._search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) ) if query.mode == VectorStoreQueryMode.SPARSE: azure_query_result_search = AzureQueryResultSearchSparse( - query, self._field_mapping, odata_filter, self._search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) elif query.mode == VectorStoreQueryMode.HYBRID: azure_query_result_search = AzureQueryResultSearchHybrid( - query, self._field_mapping, odata_filter, self._search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) elif query.mode == VectorStoreQueryMode.SEMANTIC_HYBRID: azure_query_result_search = AzureQueryResultSearchSemanticHybrid( - query, self._field_mapping, odata_filter, self._search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) return azure_query_result_search.search() @@ -1193,20 +1214,36 @@ async def aquery( azure_query_result_search: AzureQueryResultSearchBase = ( AzureQueryResultSearchDefault( - query, self._field_mapping, odata_filter, self._async_search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) ) if query.mode == VectorStoreQueryMode.SPARSE: azure_query_result_search = AzureQueryResultSearchSparse( - query, self._field_mapping, odata_filter, self._async_search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) elif query.mode == VectorStoreQueryMode.HYBRID: azure_query_result_search = AzureQueryResultSearchHybrid( - query, self._field_mapping, odata_filter, self._async_search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) elif query.mode == VectorStoreQueryMode.SEMANTIC_HYBRID: azure_query_result_search = AzureQueryResultSearchSemanticHybrid( - query, self._field_mapping, odata_filter, self._async_search_client + query, + self._field_mapping, + odata_filter, + self._search_client, + self._async_search_client, ) return await azure_query_result_search.asearch() @@ -1339,12 +1376,14 @@ def __init__( query: VectorStoreQuery, field_mapping: Dict[str, str], odata_filter: Optional[str], - search_client: Any, + search_client: SearchClient, + async_search_client: AsyncSearchClient, ) -> None: self._query = query self._field_mapping = field_mapping self._odata_filter = odata_filter self._search_client = search_client + self._async_search_client = async_search_client @property def _select_fields(self) -> List[str]: @@ -1417,7 +1456,7 @@ def _create_query_result( async def _acreate_query_result( self, search_query: str, vectors: Optional[List[Any]] ) -> VectorStoreQueryResult: - results = await self._search_client.search( + results = await self._async_search_client.search( search_text=search_query, vector_queries=vectors, top=self._query.similarity_top_k, diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml index 73798a43b5b61..85bfd1c674ef9 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-azureaisearch/pyproject.toml @@ -28,7 +28,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-vector-stores-azureaisearch" readme = "README.md" -version = "0.3.0" +version = "0.3.1" [tool.poetry.dependencies] python = ">=3.9,<4.0"