diff --git a/docs/sync.rst b/docs/sync.rst index 04333f6b..53c461ea 100644 --- a/docs/sync.rst +++ b/docs/sync.rst @@ -1,7 +1,7 @@ Synchronous Clients ===================== -This module contains clients with coroutine methods used for asynchronous programming. +This module contains clients with methods used for synchronous programming. .. toctree:: diff --git a/pyproject.toml b/pyproject.toml index c261e948..4ba77a03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Topic :: Database" ] -version = "1.0.0" +version = "1.0.1" requires-python = ">3.8" dependencies = [ "grpcio == 1.64.1", diff --git a/src/aerospike_vector_search/admin.py b/src/aerospike_vector_search/admin.py index b4cf9aba..cb7e33b0 100644 --- a/src/aerospike_vector_search/admin.py +++ b/src/aerospike_vector_search/admin.py @@ -17,7 +17,7 @@ class Client(BaseClient): This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. - :param seeds: Defines the Aerospike Database cluster nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. + :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. @@ -38,10 +38,10 @@ class Client(BaseClient): :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. 
:type root_certificate: Optional[list[bytes], bytes] - :param certificate_chain: The PEM-encoded private key as a byte string. Defaults to None. + :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. :type certificate_chain: Optional[bytes] - :param private_key: The PEM-encoded certificate chain as a byte string. Defaults to None. + :param private_key: The PEM-encoded private key as a byte string. Defaults to None. :type private_key: Optional[bytes] :raises AVSClientError: Raised when no seed host is provided. @@ -109,21 +109,24 @@ def index_create( :param vector_distance_metric: The distance metric used to compare when performing a vector search. Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type dimensions: Optional[types.VectorDistanceMetric] + :type vector_distance_metric: Optional[types.VectorDistanceMetric] :param sets: The set used for the index. Defaults to None. - :type dimensions: Optional[str] + :type sets: Optional[str] :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning vector search. Defaults to None. If index_params is None, then the default values specified for :class:`types.HnswParams` will be used. - :type dimensions: Optional[types.HnswParams] + :type index_params: Optional[types.HnswParams] :param index_labels: Meta data associated with the index. Defaults to None. - :type dimensions: Optional[dict[str, str]] + :type index_labels: Optional[dict[str, str]] + + :param index_storage: Namespace and set where index overhead (non-vector data) is stored. + :type index_storage: Optional[types.IndexStorage] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. 
@@ -178,10 +181,10 @@ def index_drop( :type name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. Note: @@ -215,12 +218,12 @@ def index_list(self, timeout: Optional[int] = None) -> list[dict]: List all indices. :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: list[dict]: A list of indices. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -246,18 +249,18 @@ def index_get( Retrieve the information related with an index. :param namespace: The namespace of the index. - :type name: str + :type namespace: str :param name: The name of the index. :type name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: dict[str, Union[int, str]: Information about an index. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index.. 
+ AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -290,7 +293,7 @@ def index_get_status( :type name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: int: Records queued to be merged into an index. @@ -341,7 +344,7 @@ def add_user( :type password: list[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: @@ -372,11 +375,11 @@ def update_credentials( :param username: Username of the user to update. :type username: str - :param password: New password for the userr. + :param password: New password for the user. :type password: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: @@ -406,7 +409,7 @@ def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: :type username: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: @@ -430,13 +433,13 @@ def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: def get_user(self, *, username: str, timeout: Optional[int] = None) -> types.User: """ - Retrieves AVS User information from the AVS Server + Retrieves AVS User information from the AVS Server. :param username: Username of the user to be retrieved. :type username: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type dimensions: int + :type timeout: int return: types.User: AVS User @@ -466,7 +469,7 @@ def list_users(self, timeout: Optional[int] = None) -> list[types.User]: List all users existing on the AVS Server. :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int return: list[types.User]: list of AVS Users @@ -494,7 +497,7 @@ def grant_roles( self, *, username: str, roles: list[str], timeout: Optional[int] = None ) -> None: """ - grant roles to existing AVS Users. + Grant roles to existing AVS Users. :param username: Username of the user which will receive the roles. :type username: str @@ -503,7 +506,7 @@ def grant_roles( :type roles: list[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. @@ -528,16 +531,16 @@ def revoke_roles( self, *, username: str, roles: list[str], timeout: Optional[int] = None ) -> None: """ - grant roles to existing AVS Users. + Revoke roles from existing AVS Users. :param username: Username of the user undergoing role removal. :type username: str - :param roles: Roles the specified user will no longer maintain.. + :param roles: Roles to be revoked. :type roles: list[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. @@ -560,10 +563,10 @@ def revoke_roles( def list_roles(self, timeout: Optional[int] = None) -> list[dict]: """ - grant roles to existing AVS Users. + List roles available on the AVS server. 
:param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int returns: list[str]: Roles available in the AVS Server. diff --git a/src/aerospike_vector_search/aio/admin.py b/src/aerospike_vector_search/aio/admin.py index 3a6475f4..ba8fea63 100644 --- a/src/aerospike_vector_search/aio/admin.py +++ b/src/aerospike_vector_search/aio/admin.py @@ -17,7 +17,7 @@ class Client(BaseClient): This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. - :param seeds: Defines the Aerospike Database cluster nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. + :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. @@ -38,10 +38,10 @@ class Client(BaseClient): :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. :type root_certificate: Optional[list[bytes], bytes] - :param certificate_chain: The PEM-encoded private key as a byte string. Defaults to None. + :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. :type certificate_chain: Optional[bytes] - :param private_key: The PEM-encoded certificate chain as a byte string. Defaults to None. + :param private_key: The PEM-encoded private key as a byte string. Defaults to None. 
:type private_key: Optional[bytes] :raises AVSClientError: Raised when no seed host is provided. @@ -109,21 +109,24 @@ async def index_create( :param vector_distance_metric: The distance metric used to compare when performing a vector search. Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type dimensions: Optional[types.VectorDistanceMetric] + :type vector_distance_metric: Optional[types.VectorDistanceMetric] :param sets: The set used for the index. Defaults to None. - :type dimensions: Optional[str] + :type sets: Optional[str] :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning vector search. Defaults to None. If index_params is None, then the default values specified for :class:`types.HnswParams` will be used. - :type dimensions: Optional[types.HnswParams] + :type index_params: Optional[types.HnswParams] :param index_labels: Meta data associated with the index. Defaults to None. - :type dimensions: Optional[dict[str, str]] + :type index_labels: Optional[dict[str, str]] + :param index_storage: Namespace and set where index overhead (non-vector data) is stored. + :type index_storage: Optional[types.IndexStorage] + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. @@ -180,10 +183,10 @@ async def index_drop( :type name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. 
This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. Note: @@ -218,12 +221,12 @@ async def index_list(self, timeout: Optional[int] = None) -> list[dict]: List all indices. :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: list[dict]: A list of indices. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ await self._channel_provider._is_ready() @@ -256,12 +259,12 @@ async def index_get( :type name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: dict[str, Union[int, str]: Information about an index. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -289,13 +292,13 @@ async def index_get_status( Retrieve the number of records queued to be merged into an index. :param namespace: The namespace of the index. - :type name: str + :type namespace: str :param name: The name of the index. :type name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type dimensions: int + :type timeout: int Returns: int: Records queued to be merged into an index. @@ -348,7 +351,7 @@ async def add_user( :type password: list[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: @@ -381,11 +384,11 @@ async def update_credentials( :param username: Username of the user to update. :type username: str - :param password: New password for the userr. + :param password: New password for the user. :type password: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: @@ -417,7 +420,7 @@ async def drop_user(self, *, username: str, timeout: Optional[int] = None) -> No :type username: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: @@ -445,13 +448,13 @@ async def get_user( self, *, username: str, timeout: Optional[int] = None ) -> types.User: """ - Retrieves AVS User information from the AVS Server + Retrieves AVS User information from the AVS Server. :param username: Username of the user to be retrieved. :type username: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int return: types.User: AVS User @@ -483,7 +486,7 @@ async def list_users(self, timeout: Optional[int] = None) -> list[types.User]: List all users existing on the AVS Server. :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type dimensions: int + :type timeout: int return: list[types.User]: list of AVS Users @@ -513,7 +516,7 @@ async def grant_roles( self, *, username: str, roles: list[str], timeout: Optional[int] = None ) -> int: """ - grant roles to existing AVS Users. + Grant roles to existing AVS Users. :param username: Username of the user which will receive the roles. :type username: str @@ -522,7 +525,7 @@ async def grant_roles( :type roles: list[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. @@ -549,16 +552,16 @@ async def revoke_roles( self, *, username: str, roles: list[str], timeout: Optional[int] = None ) -> int: """ - grant roles to existing AVS Users. + Revoke roles from existing AVS Users. :param username: Username of the user undergoing role removal. :type username: str - :param roles: Roles the specified user will no longer maintain.. + :param roles: Roles to be revoked. :type roles: list[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. @@ -583,10 +586,10 @@ async def revoke_roles( async def list_roles(self, timeout: Optional[int] = None) -> None: """ - grant roles to existing AVS Users. + List roles available on the AVS server. :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int returns: list[str]: Roles available in the AVS Server. 
diff --git a/src/aerospike_vector_search/aio/client.py b/src/aerospike_vector_search/aio/client.py index 5e09cbb7..035308dd 100644 --- a/src/aerospike_vector_search/aio/client.py +++ b/src/aerospike_vector_search/aio/client.py @@ -21,7 +21,7 @@ class Client(BaseClient): Moreover, the client supports Hierarchical Navigable Small World (HNSW) vector searches, allowing users to find vectors similar to a given query vector within an index. - :param seeds: Defines the Aerospike Database cluster nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. + :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. @@ -42,12 +42,12 @@ class Client(BaseClient): :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. :type root_certificate: Optional[list[bytes], bytes] - :param certificate_chain: The PEM-encoded private key as a byte string. Defaults to None. - :type certificate_chain: Optional[bytes] - - :param private_key: The PEM-encoded certificate chain as a byte string. Defaults to None. + :param private_key: The PEM-encoded private key as a byte string. Defaults to None. :type private_key: Optional[bytes] + :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. + :type certificate_chain: Optional[bytes] + :raises AVSClientError: Raised when no seed host is provided. """ @@ -107,15 +107,15 @@ async def insert( :param set_name: The name of the set to which the record belongs. 
Defaults to None. :type set_name: Optional[str] - :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records would be written to storage - and later, the index healer would pick for indexing. Defaults to False. - :type dimensions: int + :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records will be written to storage + and later, the index healer will pick them for indexing. Defaults to False. + :type ignore_mem_queue_full: bool :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to insert a vector.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to insert a vector. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -168,15 +168,15 @@ async def update( :param set_name: The name of the set to which the record belongs. Defaults to None. :type set_name: Optional[str] - :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records would be written to storage - and later, the index healer would pick for indexing. Defaults to False. - :type dimensions: int + :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records will be written to storage + and later, the index healer will pick them for indexing. Defaults to False. + :type ignore_mem_queue_full: bool :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a vector.. 
+ AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a vector. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -229,15 +229,15 @@ async def upsert( :param set_name: The name of the set to which the record belongs. Defaults to None. :type set_name: Optional[str] - :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records would be written to storage - and later, the index healer would pick for indexing. Defaults to False. - :type dimensions: int + :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records will be written to storage + and later, the index healer will pick them for indexing. Defaults to False. + :type ignore_mem_queue_full: bool :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to upsert a vector.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to upsert a vector. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -288,13 +288,13 @@ async def get( :type set_name: Optional[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: types.RecordWithKey: A record with its associated key. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a vector.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a vector. 
This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -336,14 +336,14 @@ async def exists( :type set_name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: bool: True if the record exists, False otherwise. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to see if a given vector exists.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to see if a given vector exists. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -385,10 +385,10 @@ async def delete( :type set_name: Optional[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to delete a record. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -435,13 +435,13 @@ async def is_indexed( :type set_name: optional[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: bool: True if the record is indexed, False otherwise. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. 
+ AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to see if a record is indexed. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -497,13 +497,13 @@ async def vector_search( :type field_names: Optional[list[str]] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: list[types.Neighbor]: A list of neighbors records found by the search. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to vector search. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ await self._channel_provider._is_ready() @@ -560,7 +560,7 @@ async def wait_for_index_completion( Raises: Exception: Raised when the timeout occurs while waiting for index completion. - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to wait for index completion. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
Note: diff --git a/src/aerospike_vector_search/aio/internal/channel_provider.py b/src/aerospike_vector_search/aio/internal/channel_provider.py index b237f774..4de7ad98 100644 --- a/src/aerospike_vector_search/aio/internal/channel_provider.py +++ b/src/aerospike_vector_search/aio/internal/channel_provider.py @@ -18,6 +18,7 @@ logger = logging.getLogger(__name__) +TEND_INTERVAL = 1 class ChannelProvider(base_channel_provider.BaseChannelProvider): """AVS Channel Provider""" @@ -47,157 +48,190 @@ def __init__( service_config_path, ) - self._tend_initalized: asyncio.Event = asyncio.Event() + # When set, client has concluded cluster tending self._tend_ended: asyncio.Event = asyncio.Event() - self._task: Optional[asyncio.Task] = None - asyncio.create_task(self._tend()) + # When set, client has completed a cluster tend cycle, initialized auth, and verified client-server minimum compatibility + self._ready: asyncio.Event = asyncio.Event() - async def close(self): - self._closed = True - await self._tend_ended.wait() + # When locked, new task is being assigned to _auth_task + self._auth_tending_lock: asyncio.Lock = asyncio.Lock() - for channel in self._seedChannels: - await channel.close() + # initializes authentication tending + self._auth_task: Optional[asyncio.Task] = asyncio.create_task(self._tend_token()) - for k, channelEndpoints in self._node_channels.items(): - if channelEndpoints.channel: - await channelEndpoints.channel.close() + # initializes client tending processes + asyncio.create_task(self._tend()) - if self._task != None: - await self._task + # Exception to propagate to main control flow from errors generated during tending + self._tend_exception: Exception = None async def _is_ready(self): - if not self.client_server_compatible: - await self._tend_initalized.wait() + # Wait 1 round of cluster tending, auth token initialization, and server client compatibility verification + await self._ready.wait() - if not self.client_server_compatible: - raise 
types.AVSClientError( - message="This AVS Client version is only compatbile with AVS Servers above the following version number: " - + self.minimum_required_version - ) + # This propagates any fatal/unexpected errors from client initialization/tending to the client. + # Raising errors in a task does not deliver this error information to users + if self._tend_exception: + raise self._tend_exception async def _tend(self): + try: - (temp_endpoints, update_endpoints_stub, channels, end_tend) = ( - self.init_tend() - ) - if self._token: - if self._check_if_token_refresh_needed(): - await self._update_token_and_ttl() + await self._auth_task - if end_tend: + # verifies server is minimally compatible with client + await self._check_server_version() - if not self.client_server_compatible: + await self._tend_cluster() - stub = vector_db_pb2_grpc.AboutServiceStub(self.get_channel()) - about_request = vector_db_pb2.AboutRequest() + self._ready.set() - self.current_server_version = ( - await stub.Get(about_request, credentials=self._token) - ).version - self.client_server_compatible = self.verify_compatibile_server() + except Exception as e: + # Set all events to prevent hanging if initial tend fails with error + self._tend_ended.set() + self._ready.set() - self._tend_initalized.set() + async def _tend_cluster(self): + try: + (channels, end_tend_cluster) = self.init_tend_cluster() + + if end_tend_cluster: self._tend_ended.set() return - stubs = [] - tasks = [] - update_endpoints_stubs = [] - for channel in channels: - - stub = vector_db_pb2_grpc.ClusterInfoServiceStub(channel) - stubs.append(stub) - try: - tasks.append( - await stub.GetClusterId(empty, credentials=self._token) - ) - except Exception as e: - logger.debug( - "While tending, failed to get cluster id with error: " + str(e) - ) + (cluster_info_stubs, tasks) = self._gather_new_cluster_ids_and_cluster_info_stubs(channels) - try: - new_cluster_ids = tasks - except Exception as e: - logger.debug( - "While tending, failed 
to gather results from GetClusterId: " - + str(e) - ) + new_cluster_ids = await asyncio.gather(*tasks) + + update_endpoints_stubs = self._gather_stubs_for_endpoint_updating(new_cluster_ids, cluster_info_stubs) + + tasks = self._gather_temp_endpoints(new_cluster_ids, update_endpoints_stubs) - for index, value in enumerate(new_cluster_ids): - if self.check_cluster_id(value.id): - update_endpoints_stubs.append(stubs[index]) - - for stub in update_endpoints_stubs: - try: - response = await stub.GetClusterEndpoints( - vector_db_pb2.ClusterNodeEndpointsRequest( - listenerName=self.listener_name - ), - credentials=self._token, - ) - temp_endpoints = self.update_temp_endpoints( - response, temp_endpoints - ) - except Exception as e: - logger.debug( - "While tending, failed to get cluster endpoints with error: " - + str(e) - ) - - tasks = [] + cluster_endpoints_list = await asyncio.gather(*tasks) + + temp_endpoints = self._assign_temporary_endpoints(cluster_endpoints_list) if update_endpoints_stubs: - for node, newEndpoints in temp_endpoints.items(): - (channel_endpoints, add_new_channel) = self.check_for_new_endpoints( - node, newEndpoints - ) - - if add_new_channel: - try: - # TODO: Wait for all calls to drain - tasks.append(channel_endpoints.channel.close()) - except Exception as e: - logger.debug( - "While tending, failed to close GRPC channel while replacing up old endpoints: " - + str(e) - ) - self.add_new_channel_to_node_channels(node, newEndpoints) - - for node, channel_endpoints in list(self._node_channels.items()): - if not temp_endpoints.get(node): - try: - # TODO: Wait for all calls to drain - tasks.append(channel_endpoints.channel.close()) - del self._node_channels[node] - - except Exception as e: - logger.debug( - "While tending, failed to close GRPC channel while removing unused endpoints: " - + str(e) - ) + + tasks = self._add_new_channels_from_temp_endpoints(temp_endpoints) + + await asyncio.gather(*tasks) + + tasks = 
self._close_old_channels_from_node_channels(temp_endpoints) await asyncio.gather(*tasks) - if not self.client_server_compatible: + await asyncio.sleep(TEND_INTERVAL) - (stub, about_request) = self._prepare_about() - self.current_server_version = ( - await stub.Get(about_request, credentials=self._token) - ).version + asyncio.create_task(self._tend_cluster()) - self.client_server_compatible = self.verify_compatibile_server() - self._tend_initalized.set() + except Exception as e: + logger.error("Unexpected tend failure: %s", e) + self._tend_exception = e + raise e + + async def _get_cluster_id_coroutine(self, stub): + try: + return await stub.GetClusterId(empty, credentials=self._token,) + except Exception as e: + logger.debug( + "While tending, failed to get cluster id with error: " + str(e) + ) - # TODO: check tend interval. - await asyncio.sleep(1) - self._task = asyncio.create_task(self._tend()) + async def _get_cluster_endpoints_coroutine(self, stub): + try: + return ( + await stub.GetClusterEndpoints( + vector_db_pb2.ClusterNodeEndpointsRequest( + listenerName=self.listener_name + ), + credentials=self._token, + ) + ).endpoints + except Exception as e: + logger.debug( + "While tending, failed to get cluster endpoints with error: " + + str(e) + ) + + async def _close_on_channel_coroutine(self, channel_endpoints): + try: + await channel_endpoints.channel.close() except Exception as e: - logger.error("Tending failed at unindentified location: %s", e) + logger.debug( + "While tending, failed to close GRPC channel: " + + str(e) + ) + + def _call_get_cluster_id(self, stub): + return asyncio.create_task(self._get_cluster_id_coroutine(stub)) + + def _call_get_cluster_endpoints(self, stub): + return asyncio.create_task(self._get_cluster_endpoints_coroutine(stub)) + + def _call_close_on_channel(self, channel_endpoints): + return asyncio.create_task(self._close_on_channel_coroutine(channel_endpoints)) + + async def _tend_token(self): + try: + if not self._token: + return 
+ elif self._token != True: + await asyncio.sleep((self._ttl * self._ttl_threshold)) + + await self._update_token_and_ttl() + + async with self._auth_tending_lock: + self._auth_task = asyncio.create_task(self._tend_token()) + + except Exception as e: + self._tend_exception = e + logger.error("Failed to tend token with error: %s", e) + raise e + + async def _update_token_and_ttl( + self, + ) -> None: + + (auth_stub, auth_request) = self._prepare_authenticate( + self._credentials, logger + ) + + try: + response = await auth_stub.Authenticate(auth_request) + except grpc.RpcError as e: + logger.error("Failed to refresh authentication token with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + self._respond_authenticate(response.token) + + async def _check_server_version(self): + try: + stub = vector_db_pb2_grpc.AboutServiceStub(self.get_channel()) + about_request = vector_db_pb2.AboutRequest() + + try: + response = await stub.Get( + about_request, credentials=self._token + ) + self.current_server_version = response.version + except grpc.RpcError as e: + logger.debug( + "Failed to retrieve server version: " + + str(e) + ) + self._tend_exception = AVSServerError(rpc_error=e) + self.verify_compatibile_server() + + except Exception as e: + logger.debug( + "Failed to retrieve server version: " + + str(e) + ) + self._tend_exception = e raise e def _create_channel(self, host: str, port: int) -> grpc.Channel: @@ -224,17 +258,19 @@ def _create_channel(self, host: str, port: int) -> grpc.Channel: else: return grpc.aio.insecure_channel(f"{host}:{port}", options=options) - async def _update_token_and_ttl( - self, - ) -> None: - (auth_stub, auth_request) = self._prepare_authenticate( - self._credentials, logger - ) + async def close(self): + # signals to tend_cluster to end cluster tending + self._closed = True - try: - response = await auth_stub.Authenticate(auth_request) - except grpc.RpcError as e: - logger.error("Failed to refresh authentication token with error: 
%s", e) - raise types.AVSServerError(rpc_error=e) + # wait until cluster tending has ended + await self._tend_ended.wait() - self._respond_authenticate(response.token) + for channel in self._seedChannels: + await channel.close() + + for k, channelEndpoints in self._node_channels.items(): + if channelEndpoints.channel: + await channelEndpoints.channel.close() + + async with self._auth_tending_lock: + self._auth_task.cancel() \ No newline at end of file diff --git a/src/aerospike_vector_search/client.py b/src/aerospike_vector_search/client.py index 54744440..aa6b0f40 100644 --- a/src/aerospike_vector_search/client.py +++ b/src/aerospike_vector_search/client.py @@ -20,7 +20,7 @@ class Client(BaseClient): Moreover, the client supports Hierarchical Navigable Small World (HNSW) vector searches, allowing users to find vectors similar to a given query vector within an index. - :param seeds: Defines the Aerospike Database cluster nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. + :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. @@ -41,10 +41,10 @@ class Client(BaseClient): :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. :type root_certificate: Optional[list[bytes], bytes] - :param certificate_chain: The PEM-encoded private key as a byte string. Defaults to None. + :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. 
:type certificate_chain: Optional[bytes] - :param private_key: The PEM-encoded certificate chain as a byte string. Defaults to None. + :param private_key: The PEM-encoded private key as a byte string. Defaults to None. :type private_key: Optional[bytes] :raises AVSClientError: Raised when no seed host is provided. @@ -106,16 +106,15 @@ def insert( :param set_name: The name of the set to which the record belongs. Defaults to None. :type set_name: Optional[str] - :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records would be written to storage - and later, the index healer would pick for indexing. Defaults to False. - :type dimensions: int + :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records will be written to storage + and later, the index healer will pick them for indexing. Defaults to False. + :type ignore_mem_queue_full: int :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int - + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to insert a vector.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to insert a vector. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -163,15 +162,15 @@ def update( :param set_name: The name of the set to which the record belongs. Defaults to None. :type set_name: Optional[str] - :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records would be written to storage - and later, the index healer would pick for indexing. Defaults to False. - :type dimensions: int + :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records will be written to storage + and later, the index healer will pick them for indexing. 
Defaults to False. + :type ignore_mem_queue_full: int :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a vector.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a vector. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -221,15 +220,15 @@ def upsert( :param set_name: The name of the set to which the record belongs. Defaults to None. :type set_name: Optional[str] - :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records would be written to storage - and later, the index healer would pick for indexing. Defaults to False. - :type dimensions: int + :param ignore_mem_queue_full: Ignore the in-memory queue full error. These records will be written to storage + and later, the index healer will pick them for indexing. Defaults to False. + :type ignore_mem_queue_full: int :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to upsert a vector.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to upsert a vector. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -279,13 +278,13 @@ def get( :type set_name: Optional[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type dimensions: int + :type timeout: int Returns: types.RecordWithKey: A record with its associated key. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a vector.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a vector. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -324,13 +323,13 @@ def exists( :type set_name: str :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: bool: True if the record exists, False otherwise. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to see if a given vector exists.. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to see if a given vector exists. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -369,10 +368,10 @@ def delete( :type set_name: Optional[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to delete the index. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -417,13 +416,13 @@ def is_indexed( :type set_name: optional[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type dimensions: int + :type timeout: int Returns: bool: True if the record is indexed, False otherwise. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to check if the vector is indexed. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -477,13 +476,13 @@ def vector_search( :type field_names: Optional[list[str]] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type dimensions: int + :type timeout: int Returns: list[types.Neighbor]: A list of neighbors records found by the search. Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to vector search. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. """ @@ -539,7 +538,7 @@ def wait_for_index_completion( Raises: Exception: Raised when the timeout occurs while waiting for index completion. - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to wait for index completion. This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
Note: diff --git a/src/aerospike_vector_search/internal/channel_provider.py b/src/aerospike_vector_search/internal/channel_provider.py index 4354c46c..8d28dc4b 100644 --- a/src/aerospike_vector_search/internal/channel_provider.py +++ b/src/aerospike_vector_search/internal/channel_provider.py @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) +TEND_INTERVAL = 1 class ChannelProvider(base_channel_provider.BaseChannelProvider): """Proximus Channel Provider""" @@ -43,151 +44,138 @@ def __init__( private_key, service_config_path, ) + # When set, client has concluded cluster tending self._tend_ended = threading.Event() - self._timer = None - self._tend() - def close(self): - self._closed = True - self._tend_ended.wait() + # When locked, new task is being assigned to _auth_task + self._auth_tending_lock: threading.Lock = threading.Lock() - for channel in self._seedChannels: - channel.close() + self._auth_timer = None - for k, channelEndpoints in self._node_channels.items(): - if channelEndpoints.channel: - channelEndpoints.channel.close() + # initializes authentication tending + self._tend_token() - if self._timer != None: - self._timer.join() + # verfies server is minimally compatible with client + self._check_server_version() - def _tend(self): - try: - (temp_endpoints, update_endpoints_stub, channels, end_tend) = ( - self.init_tend() - ) - if self._token: - if self._check_if_token_refresh_needed(): - self._update_token_and_ttl() + # initializes cluster tending + self._tend_cluster() - if end_tend: - if not self.client_server_compatible: - stub = vector_db_pb2_grpc.AboutServiceStub(self.get_channel()) - about_request = vector_db_pb2.AboutRequest() - self.current_server_version = stub.Get( - about_request, credentials=self._token - ).version - self.client_server_compatible = self.verify_compatibile_server() - if not self.client_server_compatible: - self._tend_ended.set() - raise types.AVSClientError( - message="This AVS Client version is only compatbile with AVS 
Servers above the following version number: " - + self.minimum_required_version - ) - self._tend_ended.set() + def _tend_cluster(self): + try: + (channels, end_tend_cluster) = ( + self.init_tend_cluster() + ) + if end_tend_cluster: + self._tend_ended.set() return - update_endpoints_stubs = [] - new_cluster_ids = [] - for channel in channels: - - stubs = [] - - stub = vector_db_pb2_grpc.ClusterInfoServiceStub(channel) - stubs.append(stub) - - try: - new_cluster_ids.append( - stub.GetClusterId(empty, credentials=self._credentials) - ) - - except Exception as e: - logger.debug( - "While tending, failed to get cluster id with error: " + str(e) - ) - - for index, value in enumerate(new_cluster_ids): - if self.check_cluster_id(value.id): - update_endpoints_stubs.append(stubs[index]) - - for stub in update_endpoints_stubs: - - try: - response = stub.GetClusterEndpoints( - vector_db_pb2.ClusterNodeEndpointsRequest( - listenerName=self.listener_name, - credentials=self._credentials, - ) - ) - temp_endpoints = self.update_temp_endpoints( - response, temp_endpoints - ) - except Exception as e: - logger.debug( - "While tending, failed to get cluster endpoints with error: " - + str(e) - ) + (cluster_info_stubs, new_cluster_ids) = self._gather_new_cluster_ids_and_cluster_info_stubs(channels) + + update_endpoints_stubs = self._gather_stubs_for_endpoint_updating(new_cluster_ids, cluster_info_stubs) + + cluster_endpoints_list = self._gather_temp_endpoints(new_cluster_ids, update_endpoints_stubs) + + temp_endpoints = self._assign_temporary_endpoints(cluster_endpoints_list) if update_endpoints_stubs: - for node, newEndpoints in temp_endpoints.items(): - (channel_endpoints, add_new_channel) = self.check_for_new_endpoints( - node, newEndpoints - ) - - if add_new_channel: - try: - # TODO: Wait for all calls to drain - channel_endpoints.channel.close() - except Exception as e: - logger.debug( - "While tending, failed to close GRPC channel while replacing up old endpoints:" - + str(e) - ) 
- - self.add_new_channel_to_node_channels(node, newEndpoints) - - for node, channel_endpoints in list(self._node_channels.items()): - if not self._node_channels.get(node): - try: - # TODO: Wait for all calls to drain - channel_endpoints.channel.close() - del self._node_channels[node] - - except Exception as e: - logger.debug( - "While tending, failed to close GRPC channel while removing unused endpoints: " - + str(e) - ) - - if not self.client_server_compatible: - - (stub, about_request) = self._prepare_about() - - try: - self.current_server_version = stub.Get( - about_request, credentials=self._token - ).version - except Exception as e: - logger.debug( - "While tending, failed to close GRPC channel while removing unused endpoints: " - + str(e) - ) - self.client_server_compatible = self.verify_compatibile_server() - if not self.client_server_compatible: - raise types.AVSClientError( - message="This AVS Client version is only compatbile with AVS Servers above the following version number: " - + self.minimum_required_version - ) - - self._timer = threading.Timer(1, self._tend).start() + + self._add_new_channels_from_temp_endpoints(temp_endpoints) + + self._close_old_channels_from_node_channels(temp_endpoints) + + + threading.Timer(TEND_INTERVAL, self._tend_cluster).start() + except Exception as e: logger.error("Tending failed at unindentified location: %s", e) raise e + + def _call_get_cluster_id(self, stub): + try: + return stub.GetClusterId(empty, credentials=self._token,) + except Exception as e: + logger.debug( + "While tending, failed to get cluster id with error: " + str(e) + ) + + def _call_get_cluster_endpoints(self, stub): + try: + return ( + stub.GetClusterEndpoints( + vector_db_pb2.ClusterNodeEndpointsRequest( + listenerName=self.listener_name + ), + credentials=self._token, + ) + ).endpoints + except Exception as e: + logger.debug( + "While tending, failed to get cluster endpoints with error: " + + str(e) + ) + + def _call_close_on_channel(self, 
channel_endpoints): + try: + channel_endpoints.channel.close() + except Exception as e: + logger.debug( + "While tending, failed to close GRPC channel: " + + str(e) + ) + + def _tend_token(self): + + if not self._token: + return + + self._update_token_and_ttl() + + with self._auth_tending_lock: + self._auth_timer = threading.Timer((self._ttl * self._ttl_threshold), self._tend_token) + self._auth_timer.start() + + + + def _update_token_and_ttl( + self, + ) -> None: + + (auth_stub, auth_request) = self._prepare_authenticate( + self._credentials, logger + ) + + try: + response = auth_stub.Authenticate(auth_request) + except grpc.RpcError as e: + logger.error("Failed to refresh authentication token with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + self._respond_authenticate(response.token) + + def _check_server_version(self): + stub = vector_db_pb2_grpc.AboutServiceStub(self.get_channel()) + about_request = vector_db_pb2.AboutRequest() + + try: + response = stub.Get( + about_request, credentials=self._token + ) + self.current_server_version = response.version + except grpc.RpcError as e: + logger.debug( + "Failed to retrieve server version: " + + str(e) + ) + raise AVSServerError(rpc_error=e) + self.verify_compatibile_server() + + def _create_channel(self, host: str, port: int) -> grpc.Channel: host = re.sub(r"%.*", "", host) @@ -213,18 +201,17 @@ def _create_channel(self, host: str, port: int) -> grpc.Channel: else: return grpc.insecure_channel(f"{host}:{port}", options=options) - def _update_token_and_ttl( - self, - ) -> None: + def close(self): + self._closed = True + self._tend_ended.wait() - (auth_stub, auth_request) = self._prepare_authenticate( - self._credentials, logger - ) + for channel in self._seedChannels: + channel.close() - try: - response = auth_stub.Authenticate(auth_request) - except grpc.RpcError as e: - logger.error("Failed to refresh authentication token with error: %s", e) - raise types.AVSServerError(rpc_error=e) + for k, 
channelEndpoints in self._node_channels.items(): + if channelEndpoints.channel: + channelEndpoints.channel.close() - self._respond_authenticate(response.token) + with self._auth_tending_lock: + if self._auth_timer != None: + self._auth_timer.cancel() diff --git a/src/aerospike_vector_search/shared/admin_helpers.py b/src/aerospike_vector_search/shared/admin_helpers.py index b54851cc..2349efea 100644 --- a/src/aerospike_vector_search/shared/admin_helpers.py +++ b/src/aerospike_vector_search/shared/admin_helpers.py @@ -325,4 +325,4 @@ def _prepare_wait_for_index_waiting(self, namespace, name, wait_interval): def _check_timeout(self, start_time, timeout): if start_time + timeout < time.monotonic(): - raise "timed-out waiting for index creation" + raise AVSClientError(message="timed-out waiting for index creation") diff --git a/src/aerospike_vector_search/shared/base_channel_provider.py b/src/aerospike_vector_search/shared/base_channel_provider.py index 2e065ac4..ec61cebf 100644 --- a/src/aerospike_vector_search/shared/base_channel_provider.py +++ b/src/aerospike_vector_search/shared/base_channel_provider.py @@ -124,23 +124,18 @@ def add_new_channel_to_node_channels(self, node, newEndpoints): new_channel = self._create_channel_from_server_endpoint_list(newEndpoints) self._node_channels[node] = ChannelAndEndpoints(new_channel, newEndpoints) - def init_tend(self) -> None: - end_tend = False - if self._is_loadbalancer: - # Skip tend if we are behind a load-balancer - end_tend = True - - if self._closed: - end_tend = True + def init_tend_cluster(self) -> None: - # TODO: Worry about thread safety - temp_endpoints: dict[int, vector_db_pb2.ServerEndpointList] = {} + end_tend_cluster = False + if self._is_loadbalancer or self._closed: + # Skip tend if we are behind a load-balancer + end_tend_cluster = True - update_endpoints_stub = None channels = self._seedChannels + [ x.channel for x in self._node_channels.values() ] - return (temp_endpoints, update_endpoints_stub, channels, 
end_tend) + + return (channels, end_tend_cluster) def check_cluster_id(self, new_cluster_id) -> None: if new_cluster_id == self._cluster_id: @@ -150,8 +145,7 @@ def check_cluster_id(self, new_cluster_id) -> None: return True - def update_temp_endpoints(self, response, temp_endpoints): - endpoints = response.endpoints + def update_temp_endpoints(self, endpoints, temp_endpoints): if len(endpoints) > len(temp_endpoints): return endpoints else: @@ -160,6 +154,7 @@ def update_temp_endpoints(self, response, temp_endpoints): def check_for_new_endpoints(self, node, newEndpoints): channel_endpoints = self._node_channels.get(node) + add_new_channel = True if channel_endpoints: @@ -173,13 +168,6 @@ def check_for_new_endpoints(self, node, newEndpoints): def _get_ttl(self, payload): return payload["exp"] - payload["iat"] - def _check_if_token_refresh_needed(self): - if self._token and (time.time() - self._ttl_start) > ( - self._ttl * self._ttl_threshold - ): - return True - return False - def _prepare_authenticate(self, credentials, logger): logger.debug("Refreshing auth token") auth_stub = self._get_auth_stub() @@ -205,8 +193,93 @@ def _respond_authenticate(self, token): def verify_compatibile_server(self) -> bool: def parse_version(v: str): - return tuple(int(part) if part.isdigit() else part for part in v.split(".")) + return tuple(str(part) if part.isdigit() else part for part in v.split(".")) + if parse_version(self.current_server_version) < parse_version(self.minimum_required_version): + self._tend_ended.set() + raise types.AVSClientError( + message="This AVS Client version is only compatbile with AVS Servers above the following version number: " + + self.minimum_required_version + ) + else: + self.client_server_compatible = True - return parse_version(self.current_server_version) >= parse_version( - self.minimum_required_version - ) + def _gather_new_cluster_ids_and_cluster_info_stubs(self, channels): + + stubs = [] + responses = [] + + for channel in channels: + + stub 
= vector_db_pb2_grpc.ClusterInfoServiceStub(channel) + stubs.append(stub) + + response = self._call_get_cluster_id(stub) + responses.append(response) + + return (stubs, responses) + + def _gather_stubs_for_endpoint_updating(self, new_cluster_ids, cluster_info_stubs): + update_endpoints_stubs = [] + for index, value in enumerate(new_cluster_ids): + + if self.check_cluster_id(value.id): + update_endpoints_stub = cluster_info_stubs[index] + + update_endpoints_stubs.append(update_endpoints_stub) + return update_endpoints_stubs + + def _gather_temp_endpoints(self, new_cluster_ids, update_endpoints_stubs): + + responses = [] + for stub in update_endpoints_stubs: + response = self._call_get_cluster_endpoints(stub) + + + responses.append(response) + return responses + + def _assign_temporary_endpoints(self, cluster_endpoints_list): + # TODO: Worry about thread safety + temp_endpoints: dict[int, vector_db_pb2.ServerEndpointList] = {} + for endpoints in cluster_endpoints_list: + temp_endpoints = self.update_temp_endpoints( + endpoints, temp_endpoints + ) + return temp_endpoints + + + def _add_new_channels_from_temp_endpoints(self, temp_endpoints): + responses = [] + + + for node, newEndpoints in temp_endpoints.items(): + + # Compare node channel result + (channel_endpoints, add_new_channel) = self.check_for_new_endpoints( + node, newEndpoints + ) + + + if add_new_channel: + if channel_endpoints: + response = self._call_close_on_channel(channel_endpoints) + responses.append(response) + + self.add_new_channel_to_node_channels(node, newEndpoints) + + return responses + + def _close_old_channels_from_node_channels(self, temp_endpoints): + responses = [] + + + for node, channel_endpoints in list(self._node_channels.items()): + if not temp_endpoints.get(node): + # TODO: Wait for all calls to drain + response = self._call_close_on_channel(channel_endpoints) + responses.append(response) + + + del self._node_channels[node] + + return responses \ No newline at end of file diff --git 
a/src/aerospike_vector_search/shared/client_helpers.py b/src/aerospike_vector_search/shared/client_helpers.py index 97db6c96..e13ad13e 100644 --- a/src/aerospike_vector_search/shared/client_helpers.py +++ b/src/aerospike_vector_search/shared/client_helpers.py @@ -354,9 +354,14 @@ def _prepare_wait_for_index_waiting(self, namespace, name, wait_interval): self, namespace, name, wait_interval ) + def _check_timeout(self, start_time, timeout): + if start_time + timeout < time.monotonic(): + raise AVSClientError(message="timed-out waiting for index creation") + def _check_completion_condition( self, start_time, timeout, index_status, unmerged_record_initialized ): + self._check_timeout(start_time, timeout) if start_time + 10 < time.monotonic(): unmerged_record_initialized = True diff --git a/src/aerospike_vector_search/shared/conversions.py b/src/aerospike_vector_search/shared/conversions.py index a90a70a8..6fa984aa 100644 --- a/src/aerospike_vector_search/shared/conversions.py +++ b/src/aerospike_vector_search/shared/conversions.py @@ -97,9 +97,9 @@ def fromIndexDefintion(input: types_pb2.IndexDefinition) -> types.IndexDefinitio name=input.id.name, ), dimensions=input.dimensions, - # vector_distance_metric=input.vectorDistanceMetric, + vector_distance_metric=input.vectorDistanceMetric, field=input.field, - # sets=input.setFilter + sets=input.setFilter, hnsw_params=types.HnswParams( m=input.hnswParams.m, ef_construction=input.hnswParams.efConstruction, @@ -108,22 +108,22 @@ def fromIndexDefintion(input: types_pb2.IndexDefinition) -> types.IndexDefinitio max_records=input.hnswParams.batchingParams.maxRecords, interval=input.hnswParams.batchingParams.interval, ), - # caching_params=types.HnswCachingParams( - # max_entries=input.hnswParams.cachingParams.maxEntries, - # expiry=input.hnswParams.cachingParams.expiry, - # ), - # healer_params=types.HnswHealerParams( - # max_scan_rate_per_node=input.hnswParams.healerParams.maxScanRatePerNode, - # 
max_scan_page_ize=input.hnswParams.healerParams.maxScanPageSize, - # re_index_percent=input.hnswParams.healerParams.reindexPercent, - # schedule_delay=input.hnswParams.healerParams.scheduleDelay, - # parallelism=input.hnswParams.healerParams.parallelism - # ), - # merge_params=types.HnswMergeParams( - # parallelism=input.hnswParams.mergeParams.parallelism - # ) + caching_params=types.HnswCachingParams( + max_entries=input.hnswParams.cachingParams.maxEntries, + expiry=input.hnswParams.cachingParams.expiry, + ), + healer_params=types.HnswHealerParams( + max_scan_rate_per_node=input.hnswParams.healerParams.maxScanRatePerNode, + max_scan_page_size=input.hnswParams.healerParams.maxScanPageSize, + re_index_percent=input.hnswParams.healerParams.reindexPercent, + schedule_delay=input.hnswParams.healerParams.scheduleDelay, + parallelism=input.hnswParams.healerParams.parallelism + ), + merge_params=types.HnswIndexMergeParams( + parallelism=input.hnswParams.mergeParams.parallelism + ) ), - # labels=input.labels, + index_labels=input.labels, storage=types.IndexStorage( namespace=input.storage.namespace, set_name=input.storage.set ), diff --git a/src/aerospike_vector_search/types.py b/src/aerospike_vector_search/types.py index 81ea13e0..2b882a57 100644 --- a/src/aerospike_vector_search/types.py +++ b/src/aerospike_vector_search/types.py @@ -224,23 +224,25 @@ class HnswBatchingParams(object): """ Parameters for configuring batching behaviour for batch based index update. - :param max_records: Maximum number of records to fit in a batch. Defaults to 10000. - :param interva: The maximum amount of time in milliseconds to wait before finalizing a batch. Defaults to 10000. + :param max_records: Maximum number of records to fit in a batch. Defaults to server default.. + :param interval: The maximum amount of time in milliseconds to wait before finalizing a batch. Defaults to server default.. 
""" def __init__( self, *, - max_records: Optional[int] = 10000, - interval: Optional[int] = 10000, + max_records: Optional[int] = None, + interval: Optional[int] = None, ) -> None: self.max_records = max_records self.interval = interval def _to_pb2(self): params = types_pb2.HnswBatchingParams() - params.maxRecords = self.max_records - params.interval = self.interval + if self.max_records: + params.maxRecords = self.max_records + if self.interval: + params.interval = self.interval return params def __repr__(self) -> str: @@ -400,9 +402,9 @@ class HnswParams(object): def __init__( self, *, - m: Optional[int] = 16, - ef_construction: Optional[int] = 100, - ef: Optional[int] = 100, + m: Optional[int] = None, + ef_construction: Optional[int] = None, + ef: Optional[int] = None, batching_params: Optional[HnswBatchingParams] = HnswBatchingParams(), max_mem_queue_size: Optional[int] = None, caching_params: Optional[HnswCachingParams] = HnswCachingParams(), @@ -420,9 +422,15 @@ def __init__( def _to_pb2(self): params = types_pb2.HnswParams() - params.m = self.m - params.efConstruction = self.ef_construction - params.ef = self.ef + if self.m: + params.m = self.m + + if self.ef_construction: + params.efConstruction = self.ef_construction + + if self.ef: + params.ef = self.ef + if self.max_mem_queue_size: params.maxMemQueueSize = self.max_mem_queue_size @@ -581,14 +589,31 @@ def __setitem__(self, key, value): class IndexDefinition(object): """ - AVS Index Defintion + AVS Index Definition - :param username: Username associated with user. - :type username: str + :param id: Index ID. + :type id: str - :param roles: roles associated with user. - :type roles: list[str] + :param dimensions: Number of dimensions. + :type dimensions: int + + :param vector_distance_metric: Metric used to evaluate vector searches on the given index + :type vector_distance_metric: VectorDistanceMetric + + :param field: Field name. 
+ :type field: str + + :param sets: Set name + :type sets: str + + :param hnsw_params: HNSW parameters. + :type hnsw_params: HnswParams + + :param storage: Index storage details. + :type storage: IndexStorage + :param index_labels: Meta data associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] """ def __init__( @@ -596,26 +621,34 @@ def __init__( *, id: str, dimensions: int, + vector_distance_metric: types_pb2.VectorDistanceMetric, field: str, + sets: str, hnsw_params: HnswParams, storage: IndexStorage, + index_labels: dict[str, str] ) -> None: self.id = id self.dimensions = dimensions + self.vector_distance_metric = vector_distance_metric self.field = field + self.sets = sets self.hnsw_params = hnsw_params self.storage = storage + self.index_labels = index_labels def __repr__(self) -> str: return ( - f"IndexDefinition(id={self.id!r}, dimensions={self.dimensions}, field={self.field!r}, " - f"hnsw_params={self.hnsw_params!r}, storage={self.storage!r})" + f"IndexDefinition(id={self.id!r}, dimensions={self.dimensions}, field={self.field!r}, sets={self.sets!r}," + f"vector_distance_metric={self.vector_distance_metric!r}, hnsw_params={self.hnsw_params!r}, storage={self.storage!r}, " + f"index_labels={self.index_labels}" ) def __str__(self) -> str: return ( - f"IndexDefinition(id={self.id}, dimensions={self.dimensions}, field={self.field}, " - f"hnsw_params={self.hnsw_params}, storage={self.storage})" + f"IndexDefinition(id={self.id}, dimensions={self.dimensions}, field={self.field}, sets={self.sets!r}, " + f"vector_distance_metric={self.vector_distance_metric}, hnsw_params={self.hnsw_params}, storage={self.storage}, " + f"index_labels={self.index_labels}" ) def __eq__(self, other) -> bool: @@ -624,9 +657,12 @@ def __eq__(self, other) -> bool: return ( self.id == other.id and self.dimensions == other.dimensions + and self.vector_distance_metric == other.vector_distance_metric and self.field == other.field + and self.sets == 
other.sets and self.hnsw_params == other.hnsw_params and self.storage == other.storage + and self.index_labels == other.index_labels ) def __getitem__(self, key): diff --git a/tests/standard/aio/conftest.py b/tests/standard/aio/conftest.py index 9b9db6e3..8284346c 100644 --- a/tests/standard/aio/conftest.py +++ b/tests/standard/aio/conftest.py @@ -4,6 +4,10 @@ from aerospike_vector_search.aio.admin import Client as AdminClient from aerospike_vector_search import types +#import logging +#logger = logging.getLogger(__name__) +#logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.DEBUG) + @pytest.fixture(scope="module", autouse=True) async def drop_all_indexes( @@ -36,15 +40,16 @@ async def drop_all_indexes( certificate_chain=certificate_chain, private_key=private_key, ) as client: - index_list = await client.index_list() + index_list = await client.index_list() tasks = [] for item in index_list: - tasks.append(client.index_drop(namespace="test", name=item["id"]["name"])) + tasks.append(asyncio.create_task(client.index_drop(namespace="test", name=item["id"]["name"]))) await asyncio.gather(*tasks) + @pytest.fixture(scope="module") async def session_admin_client( host, @@ -56,7 +61,6 @@ async def session_admin_client( private_key, is_loadbalancer, ): - if root_certificate: with open(root_certificate, "rb") as f: root_certificate = f.read() @@ -77,6 +81,7 @@ async def session_admin_client( username=username, password=password, ) + yield client await client.close() diff --git a/tests/standard/aio/test_admin_client_index_create.py b/tests/standard/aio/test_admin_client_index_create.py index 48a43266..16ae1779 100644 --- a/tests/standard/aio/test_admin_client_index_create.py +++ b/tests/standard/aio/test_admin_client_index_create.py @@ -7,6 +7,15 @@ from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity, Phase +server_defaults = { + "m": 16, + "ef_construction": 100, + "ef": 100, + "batching_params": { + 
"max_records": 10000, + "interval": 10000, + } +} class index_create_test_case: def __init__( @@ -344,6 +353,19 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na dimensions=1024, vector_distance_metric=None, sets=None, + index_params=types.HnswParams( + m=8, + ), + index_labels=None, + index_storage=None, + timeout=None, + ), + index_create_test_case( + namespace="test", + vector_field="example_13", + dimensions=1024, + vector_distance_metric=None, + sets=None, index_params=types.HnswParams( batching_params=types.HnswBatchingParams(max_records=500, interval=500) ), @@ -351,6 +373,27 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na index_storage=None, timeout=None, ), + index_create_test_case( + namespace="test", + vector_field="example_20", + dimensions=1024, + vector_distance_metric=None, + sets="demo", + index_params=types.HnswParams( + caching_params=types.HnswCachingParams(max_entries=10, expiry=3000), + healer_params=types.HnswHealerParams( + max_scan_rate_per_node=80, + max_scan_page_size=40, + re_index_percent=50, + schedule_delay=5, + parallelism=4, + ), + merge_params=types.HnswIndexMergeParams(parallelism=10), + ), + index_labels=None, + index_storage=None, + timeout=None, + ), ], ) async def test_index_create_with_index_params( @@ -376,19 +419,19 @@ async def test_index_create_with_index_params( assert result["id"]["namespace"] == test_case.namespace assert result["dimensions"] == test_case.dimensions assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == test_case.index_params.m + assert result["hnsw_params"]["m"] == test_case.index_params.m or server_defaults assert ( result["hnsw_params"]["ef_construction"] - == test_case.index_params.ef_construction + == test_case.index_params.ef_construction or server_defaults ) - assert result["hnsw_params"]["ef"] == test_case.index_params.ef + assert result["hnsw_params"]["ef"] == test_case.index_params.ef or 
server_defaults assert ( result["hnsw_params"]["batching_params"]["max_records"] - == test_case.index_params.batching_params.max_records + == test_case.index_params.batching_params.max_records or server_defaults ) assert ( result["hnsw_params"]["batching_params"]["interval"] - == test_case.index_params.batching_params.interval + == test_case.index_params.batching_params.interval or server_defaults ) assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name @@ -403,7 +446,7 @@ async def test_index_create_with_index_params( [ index_create_test_case( namespace="test", - vector_field="example_13", + vector_field="example_14", dimensions=1024, vector_distance_metric=None, sets=None, @@ -455,7 +498,7 @@ async def test_index_create_index_labels(session_admin_client, test_case, random [ index_create_test_case( namespace="test", - vector_field="example_14", + vector_field="example_15", dimensions=1024, vector_distance_metric=None, sets=None, @@ -504,7 +547,7 @@ async def test_index_create_index_storage(session_admin_client, test_case, rando [ index_create_test_case( namespace="test", - vector_field="example_15", + vector_field="example_16", dimensions=1024, vector_distance_metric=None, sets=None, diff --git a/tests/standard/sync/conftest.py b/tests/standard/sync/conftest.py index 4b1323f9..8a00629a 100644 --- a/tests/standard/sync/conftest.py +++ b/tests/standard/sync/conftest.py @@ -1,8 +1,12 @@ import pytest + from aerospike_vector_search import Client from aerospike_vector_search.admin import Client as AdminClient from aerospike_vector_search import types +#import logging +#logger = logging.getLogger(__name__) +#logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.DEBUG) @pytest.fixture(scope="module", autouse=True) def drop_all_indexes( diff --git a/tests/standard/sync/test_admin_client_index_create.py b/tests/standard/sync/test_admin_client_index_create.py index 2bf065ed..a448e7ba 100644 --- 
a/tests/standard/sync/test_admin_client_index_create.py +++ b/tests/standard/sync/test_admin_client_index_create.py @@ -7,6 +7,15 @@ from .sync_utils import drop_specified_index from hypothesis import given, settings, Verbosity +server_defaults = { + "m": 16, + "ef_construction": 100, + "ef": 100, + "batching_params": { + "max_records": 10000, + "interval": 10000, + } +} class index_create_test_case: def __init__( @@ -358,6 +367,19 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): index_storage=None, timeout=None, ), + index_create_test_case( + namespace="test", + vector_field="example_20", + dimensions=1024, + vector_distance_metric=None, + sets=None, + index_params=types.HnswParams( + m=8, + ), + index_labels=None, + index_storage=None, + timeout=None, + ), index_create_test_case( namespace="test", vector_field="example_12", @@ -421,25 +443,25 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ assert result["id"]["namespace"] == test_case.namespace assert result["dimensions"] == test_case.dimensions assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == test_case.index_params.m + assert result["hnsw_params"]["m"] == test_case.index_params.m or server_defaults assert ( result["hnsw_params"]["ef_construction"] - == test_case.index_params.ef_construction + == test_case.index_params.ef_construction or server_defaults ) - assert result["hnsw_params"]["ef"] == test_case.index_params.ef + assert result["hnsw_params"]["ef"] == test_case.index_params.ef or server_defaults if getattr(result.hnsw_params, 'max_mem_queue_size', None) is not None: assert ( result["hnsw_params"]["max_mem_queue_size"] - == test_case.index_params.max_mem_queue_size + == test_case.index_params.max_mem_queue_size or server_defaults ) assert ( result["hnsw_params"]["batching_params"]["max_records"] - == test_case.index_params.batching_params.max_records + == 
(test_case.index_params.batching_params.max_records or server_defaults["batching_params"]["max_records"]) ) assert ( result["hnsw_params"]["batching_params"]["interval"] - == test_case.index_params.batching_params.interval + == (test_case.index_params.batching_params.interval or server_defaults["batching_params"]["interval"]) ) """ if getattr(result.hnsw_params, 'caching_params', None) is not None: