From d43faf70650c312c7c30351f831fb84626ecb6b0 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 20 Dec 2024 10:42:25 -0800 Subject: [PATCH 01/21] refactor: propogate index and record fixtures through async tests --- src/aerospike_vector_search/types.py | 27 +- tests/standard/aio/aio_utils.py | 22 +- tests/standard/aio/conftest.py | 59 ++- .../aio/test_admin_client_index_drop.py | 25 +- .../aio/test_admin_client_index_get.py | 56 +-- .../aio/test_admin_client_index_get_status.py | 31 +- .../aio/test_admin_client_index_list.py | 12 +- .../aio/test_admin_client_index_update.py | 38 +- .../standard/aio/test_vector_client_delete.py | 38 +- .../standard/aio/test_vector_client_exists.py | 30 +- tests/standard/aio/test_vector_client_get.py | 170 ++++---- .../standard/aio/test_vector_client_insert.py | 18 +- .../aio/test_vector_client_is_indexed.py | 71 +--- .../aio/test_vector_client_search_by_key.py | 386 +++++++++--------- .../standard/aio/test_vector_client_update.py | 29 +- .../standard/aio/test_vector_client_upsert.py | 16 +- tests/standard/aio/test_vector_search.py | 30 +- tests/utils.py | 7 + 18 files changed, 509 insertions(+), 556 deletions(-) diff --git a/src/aerospike_vector_search/types.py b/src/aerospike_vector_search/types.py index 8cb7b71d..59e3c214 100644 --- a/src/aerospike_vector_search/types.py +++ b/src/aerospike_vector_search/types.py @@ -528,21 +528,36 @@ class HnswParams(object): """ Parameters for the Hierarchical Navigable Small World (HNSW) algorithm, used for approximate nearest neighbor search. - :param m: The number of bi-directional links created per level during construction. Larger 'm' values lead to higher recall but slower construction. Defaults to 16. + :param m: The number of bi-directional links created per level during construction. Larger 'm' values lead to higher recall but slower construction. Optional, Defaults to 16. 
:type m: Optional[int] - :param ef_construction: The size of the dynamic list for the nearest neighbors (candidates) during the index construction. Larger 'ef_construction' values lead to higher recall but slower construction. Defaults to 100. + :param ef_construction: The size of the dynamic list for the nearest neighbors (candidates) during the index construction. Larger 'ef_construction' values lead to higher recall but slower construction. Optional, Defaults to 100. :type ef_construction: Optional[int] - :param ef: The size of the dynamic list for the nearest neighbors (candidates) during the search phase. Larger 'ef' values lead to higher recall but slower search. Defaults to 100. + :param ef: The size of the dynamic list for the nearest neighbors (candidates) during the search phase. Larger 'ef' values lead to higher recall but slower search. Optional, Defaults to 100. :type ef: Optional[int] - :param batching_params: Parameters related to configuring batch processing, such as the maximum number of records per batch and batching interval. Defaults to HnswBatchingParams(). - :type batching_params: Optional[HnswBatchingParams] + :param batching_params: Parameters related to configuring batch processing, such as the maximum number of records per batch and batching interval. Optional, Defaults to HnswBatchingParams(). + :type batching_params: HnswBatchingParams - :param enable_vector_integrity_check: Verifies if the underlying vector has changed before returning the kANN result. + :param max_mem_queue_size: Maximum size of in-memory queue for inserted/updated vector records. Optional, Defaults to the corresponding config on the AVS Server. + :type max_mem_queue_size: Optional[int] + + :param index_caching_params: Parameters related to configuring caching for the HNSW index. Optional, Defaults to HnswCachingParams(). + :type index_caching_params: HnswCachingParams + + :param healer_params: Parameters related to configuring the HNSW index healer. 
Optional, Defaults to HnswHealerParams(). + :type healer_params: HnswHealerParams + + :param merge_params: Parameters related to configuring the merging of index records. Optional, Defaults to HnswIndexMergeParams(). + :type merge_params: HnswIndexMergeParams + + :param enable_vector_integrity_check: Verifies if the underlying vector has changed before returning the kANN result. Optional, Defaults to True. :type enable_vector_integrity_check: Optional[bool] + :param record_caching_params: Parameters related to configuring caching for vector records. Optional, Defaults to HnswCachingParams(). + :type record_caching_params: HnswCachingParams + """ def __init__( diff --git a/tests/standard/aio/aio_utils.py b/tests/standard/aio/aio_utils.py index 6dea216e..c07fcf1f 100644 --- a/tests/standard/aio/aio_utils.py +++ b/tests/standard/aio/aio_utils.py @@ -1,5 +1,4 @@ -async def drop_specified_index(admin_client, namespace, name): - await admin_client.index_drop(namespace=namespace, name=name) +import asyncio def gen_records(count: int, vec_bin: str, vec_dim: int): num = 0 @@ -10,3 +9,22 @@ def gen_records(count: int, vec_bin: str, vec_dim: int): ) yield key_and_rec num += 1 + + +async def wait_for_index(admin_client, namespace: str, index: str): + + verticies = 0 + unmerged_recs = 0 + + while verticies == 0 or unmerged_recs > 0: + status = await admin_client.index_get_status( + namespace=namespace, + name=index, + ) + + verticies = status.index_healer_vertices_valid + unmerged_recs = status.unmerged_record_count + + # print(verticies) + # print(unmerged_recs) + await asyncio.sleep(0.5) \ No newline at end of file diff --git a/tests/standard/aio/conftest.py b/tests/standard/aio/conftest.py index 97429f6d..c619b802 100644 --- a/tests/standard/aio/conftest.py +++ b/tests/standard/aio/conftest.py @@ -2,12 +2,14 @@ import pytest import random import string +import grpc from aerospike_vector_search.aio import Client from aerospike_vector_search.aio.admin import Client as 
AdminClient -from aerospike_vector_search import types +from aerospike_vector_search import types, AVSServerError from .aio_utils import gen_records +import utils #import logging #logger = logging.getLogger(__name__) @@ -15,9 +17,9 @@ # default test values -DEFAULT_NAMESPACE = "test" -DEFAULT_INDEX_DIMENSION = 128 -DEFAULT_VECTOR_FIELD = "vector" +DEFAULT_NAMESPACE = utils.DEFAULT_NAMESPACE +DEFAULT_INDEX_DIMENSION = utils.DEFAULT_INDEX_DIMENSION +DEFAULT_VECTOR_FIELD = utils.DEFAULT_VECTOR_FIELD DEFAULT_INDEX_ARGS = { "namespace": DEFAULT_NAMESPACE, "vector_field": DEFAULT_VECTOR_FIELD, @@ -199,14 +201,24 @@ def index_name(): @pytest.fixture(params=[DEFAULT_INDEX_ARGS]) async def index(session_admin_client, index_name, request): - index_args = request.param + args = request.param + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) await session_admin_client.index_create( name = index_name, - **index_args, + namespace = namespace, + vector_field = vector_field, + dimensions = dimensions, ) yield index_name - namespace = index_args.get("namespace", DEFAULT_NAMESPACE) - await session_admin_client.index_drop(namespace=namespace, name=index_name) + try: + await session_admin_client.index_drop(namespace=namespace, name=index_name) + except AVSServerError as se: + if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: + pass + else: + raise @pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) @@ -217,10 +229,35 @@ async def records(session_vector_client, request): num_records = args.get("num_records", DEFAULT_NUM_RECORDS) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) keys = [] for key, rec in record_generator(count=num_records, vec_bin=vector_field, vec_dim=dimensions): - await session_vector_client.upsert(namespace=namespace, key=key, 
record_data=rec) + await session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) keys.append(key) - yield len(keys) + yield keys for key in keys: - await session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file + await session_vector_client.delete(key=key, namespace=namespace) + + +@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) +async def record(session_vector_client, request): + args = request.param + record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) + key, rec = next(record_generator(count=1, vec_bin=vector_field, vec_dim=dimensions)) + await session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) + yield key + await session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file diff --git a/tests/standard/aio/test_admin_client_index_drop.py b/tests/standard/aio/test_admin_client_index_drop.py index b49e473b..9128a5c8 100644 --- a/tests/standard/aio/test_admin_client_index_drop.py +++ b/tests/standard/aio/test_admin_client_index_drop.py @@ -2,7 +2,7 @@ from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_name +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -11,40 +11,29 @@ @pytest.mark.parametrize("empty_test_case", [None, None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=2000) -async def test_index_drop(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - dimensions=1024, - ) - await session_admin_client.index_drop(namespace="test", name=random_name) 
+async def test_index_drop(session_admin_client, empty_test_case, index): + await session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) result = session_admin_client.index_list() result = await result for index in result: - assert index["id"]["name"] != random_name + assert index["id"]["name"] != index @pytest.mark.parametrize("empty_test_case", [None, None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) async def test_index_drop_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - dimensions=1024, - ) + with pytest.raises(AVSServerError) as e_info: for i in range(10): await session_admin_client.index_drop( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_admin_client_index_get.py b/tests/standard/aio/test_admin_client_index_get.py index 5ebeaf1f..b584da8a 100644 --- a/tests/standard/aio/test_admin_client_index_get.py +++ b/tests/standard/aio/test_admin_client_index_get.py @@ -1,7 +1,6 @@ import pytest -from ...utils import random_name +from ...utils import DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD -from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity from aerospike_vector_search import AVSServerError @@ -11,19 +10,13 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - 
vector_field="science", - dimensions=1024, - ) - result = await session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=True) - - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" +async def test_index_get(session_admin_client, empty_test_case, index): + result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) + + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD assert result["hnsw_params"]["m"] == 16 assert result["hnsw_params"]["ef_construction"] == 100 assert result["hnsw_params"]["ef"] == 100 @@ -31,9 +24,9 @@ async def test_index_get(session_admin_client, empty_test_case, random_name): assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == "test" - assert result["storage"].set_name == random_name - assert result["storage"]["set_name"] == random_name + assert result["storage"]["namespace"] == DEFAULT_NAMESPACE + assert result["storage"].set_name == index + assert result["storage"]["set_name"] == index # Defaults assert result["sets"] == "" @@ -53,26 +46,17 @@ async def test_index_get(session_admin_client, empty_test_case, random_name): # assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 80 # assert result["hnsw_params"]["merge_params"]["reindex_parallelism"] == 26 - await drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async 
def test_index_get_no_defaults(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +async def test_index_get_no_defaults(session_admin_client, empty_test_case, index): + result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) - result = await session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=False) - - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD # Defaults assert result["sets"] == "" @@ -100,14 +84,12 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["storage"].set_name == "" assert result["storage"]["set_name"] == "" - await drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) async def test_index_get_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: @@ -116,7 +98,7 @@ async def test_index_get_timeout( for i in range(10): try: result = await session_admin_client.index_get( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/aio/test_admin_client_index_get_status.py b/tests/standard/aio/test_admin_client_index_get_status.py index 3f8bd4cb..fec3fc21 100644 --- 
a/tests/standard/aio/test_admin_client_index_get_status.py +++ b/tests/standard/aio/test_admin_client_index_get_status.py @@ -1,7 +1,6 @@ import pytest -from ...utils import random_name +from ...utils import DEFAULT_NAMESPACE -from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity from aerospike_vector_search import types, AVSServerError @@ -11,46 +10,26 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_status(session_admin_client, empty_test_case, random_name): - try: - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except Exception as e: - pass - +async def test_index_get_status(session_admin_client, empty_test_case, index): result : types.IndexStatusResponse = await session_admin_client.index_get_status( - namespace="test", name=random_name + namespace=DEFAULT_NAMESPACE, name=index ) assert result.unmerged_record_count == 0 - await drop_specified_index(session_admin_client, "test", random_name) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) async def test_index_get_status_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except Exception as e: - pass for i in range(10): try: result = await session_admin_client.index_get_status( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git 
a/tests/standard/aio/test_admin_client_index_list.py b/tests/standard/aio/test_admin_client_index_list.py index 9304320b..74c4bb9d 100644 --- a/tests/standard/aio/test_admin_client_index_list.py +++ b/tests/standard/aio/test_admin_client_index_list.py @@ -5,22 +5,15 @@ import pytest import grpc -from ...utils import random_name +from utils import DEFAULT_NAMESPACE, DEFAULT_VECTOR_FIELD, DEFAULT_INDEX_DIMENSION -from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_list(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +async def test_index_list(session_admin_client, empty_test_case, index): result = await session_admin_client.index_list(apply_defaults=True) assert len(result) > 0 for index in result: @@ -37,7 +30,6 @@ async def test_index_list(session_admin_client, empty_test_case, random_name): assert isinstance(index["hnsw_params"]["batching_params"]["reindex_interval"], int) assert isinstance(index["storage"]["namespace"], str) assert isinstance(index["storage"]["set_name"], str) - await drop_specified_index(session_admin_client, "test", random_name) async def test_index_list_timeout(session_admin_client, with_latency): diff --git a/tests/standard/aio/test_admin_client_index_update.py b/tests/standard/aio/test_admin_client_index_update.py index 6e6466d1..be86a5dc 100644 --- a/tests/standard/aio/test_admin_client_index_update.py +++ b/tests/standard/aio/test_admin_client_index_update.py @@ -3,7 +3,7 @@ from aerospike_vector_search import types, AVSServerError import grpc -from .aio_utils import drop_specified_index +from utils import DEFAULT_NAMESPACE server_defaults = { "m": 16, @@ -20,7 +20,6 @@ class index_update_test_case: def __init__( 
self, *, - namespace, vector_field, dimensions, initial_labels, @@ -28,7 +27,6 @@ def __init__( hnsw_index_update, timeout ): - self.namespace = namespace self.vector_field = vector_field self.dimensions = dimensions self.initial_labels = initial_labels @@ -41,7 +39,6 @@ def __init__( "test_case", [ index_update_test_case( - namespace="test", vector_field="update_2", dimensions=256, initial_labels={"status": "active"}, @@ -63,31 +60,11 @@ def __init__( ), ], ) -async def test_index_update_async(session_admin_client, test_case): - # Create a unique index name for each test run - trimmed_random = "aBEd-1" - - # Drop any pre-existing index with the same name - try: - session_admin_client.index_drop(namespace="test", name=trimmed_random) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: - pass - - # Create the index - await session_admin_client.index_create( - namespace=test_case.namespace, - name=trimmed_random, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - index_labels=test_case.initial_labels, - timeout=test_case.timeout, - ) - +async def test_index_update_async(session_admin_client, test_case, index): # Update the index with new labels and parameters await session_admin_client.index_update( - namespace=test_case.namespace, - name=trimmed_random, + namespace=DEFAULT_NAMESPACE, + name=index, index_labels=test_case.update_labels, hnsw_update_params=test_case.hnsw_index_update ) @@ -96,10 +73,10 @@ async def test_index_update_async(session_admin_client, test_case): time.sleep(10) # Verify the update - result = await session_admin_client.index_get(namespace=test_case.namespace, name=trimmed_random, apply_defaults=True) + result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." 
- assert result["id"]["namespace"] == test_case.namespace + assert result["id"]["namespace"] == DEFAULT_NAMESPACE # Assertions based on provided parameters if test_case.hnsw_index_update.batching_params: @@ -131,6 +108,3 @@ async def test_index_update_async(session_admin_client, test_case): "max_scan_rate_per_node"] == test_case.hnsw_index_update.healer_params.max_scan_rate_per_node assert result["hnsw_params"]["enable_vector_integrity_check"] == test_case.hnsw_index_update.enable_vector_integrity_check - - # Clean up by dropping the index after the test - await drop_specified_index(session_admin_client, test_case.namespace, trimmed_random) diff --git a/tests/standard/aio/test_vector_client_delete.py b/tests/standard/aio/test_vector_client_delete.py index 8daf87e6..09bb935d 100644 --- a/tests/standard/aio/test_vector_client_delete.py +++ b/tests/standard/aio/test_vector_client_delete.py @@ -1,6 +1,6 @@ import pytest from aerospike_vector_search import AVSServerError -from ...utils import random_key +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import grpc @@ -11,13 +11,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = record_data self.timeout = timeout @@ -27,33 +25,27 @@ def __init__( "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -async def test_vector_delete(session_vector_client, test_case, random_key): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +async def test_vector_delete(session_vector_client, test_case, record): await 
session_vector_client.delete( namespace=test_case.namespace, - key=random_key, + key=record, + set_name=test_case.set_name, + timeout=test_case.timeout, ) with pytest.raises(AVSServerError) as e_info: result = await session_vector_client.get( - namespace=test_case.namespace, key=random_key + namespace=test_case.namespace, key=record ) @@ -63,19 +55,18 @@ async def test_vector_delete(session_vector_client, test_case, random_key): "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), ], ) async def test_vector_delete_without_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): await session_vector_client.delete( namespace=test_case.namespace, - key=random_key, + key=record, ) @@ -86,15 +77,14 @@ async def test_vector_delete_without_record( [ None, delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=0.0001, ), ], ) async def test_vector_delete_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -102,6 +92,6 @@ async def test_vector_delete_timeout( for i in range(10): await session_vector_client.delete( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_exists.py b/tests/standard/aio/test_vector_client_exists.py index 699110fa..3b2b5167 100644 --- a/tests/standard/aio/test_vector_client_exists.py +++ b/tests/standard/aio/test_vector_client_exists.py @@ -1,6 +1,8 @@ import pytest import grpc -from ...utils import random_key + +from 
aerospike_vector_search import AVSServerError +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -10,13 +12,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = record_data self.timeout = timeout @@ -26,35 +26,27 @@ def __init__( "test_case", [ exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -async def test_vector_exists(session_vector_client, test_case, random_key): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +async def test_vector_exists(session_vector_client, test_case, record): result = await session_vector_client.exists( namespace=test_case.namespace, - key=random_key, + key=record, ) assert result is True await session_vector_client.delete( namespace=test_case.namespace, - key=random_key, + key=record, ) @@ -64,18 +56,18 @@ async def test_vector_exists(session_vector_client, test_case, random_key): "test_case", [ exists_test_case( - namespace="test", set_name=None, record_data=None, timeout=0.0001 + namespace="test", set_name=None, timeout=0.0001 ), ], ) async def test_vector_exists_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") with pytest.raises(AVSServerError) as e_info: for i in range(10): result = await session_vector_client.exists( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) assert 
e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_get.py b/tests/standard/aio/test_vector_client_get.py index b354d773..fa20568b 100644 --- a/tests/standard/aio/test_vector_client_get.py +++ b/tests/standard/aio/test_vector_client_get.py @@ -1,12 +1,29 @@ import pytest -from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_key +from aerospike_vector_search import AVSServerError +from utils import DEFAULT_NAMESPACE, random_key from hypothesis import given, settings, Verbosity +# gen_records is used with the records test fixture +def gen_record(count, vec_bin, vec_dim): + num = 0 + while num < count: + key_and_rec = ( + num, + { + "bin1": num, + "bin2": num, + "bin3": num, + vec_bin: [float(num)] * vec_dim + } + ) + yield key_and_rec + num += 1 + + class get_test_case: def __init__( self, @@ -15,7 +32,6 @@ def __init__( include_fields, exclude_fields, set_name, - record_data, expected_fields, timeout, ): @@ -23,7 +39,6 @@ def __init__( self.include_fields = include_fields self.exclude_fields = exclude_fields self.set_name = set_name - self.record_data = record_data self.expected_fields = expected_fields self.timeout = timeout @@ -31,83 +46,92 @@ def __init__( #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( - "test_case", + "record,test_case", [ - get_test_case( - namespace="test", - include_fields=["skills"], - exclude_fields = None, - set_name=None, - record_data={"skills": [i for i in range(1024)]}, - expected_fields={"skills": [i for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"skills": 1024}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["skills"], + exclude_fields = None, + set_name=None, + expected_fields={"skills": 1024}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - 
exclude_fields = None, - set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, - expected_fields={"english": [float(i) for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": [float(i) for i in range(1024)]}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields = None, + set_name=None, + expected_fields={"english": [float(i) for i in range(1024)]}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - exclude_fields = None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields = None, + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=["spanish"], + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["spanish"], - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["spanish"], + exclude_fields=["spanish"], + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - 
include_fields=[], - exclude_fields=None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=[], + exclude_fields=None, + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=[], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1, "spanish": 2}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=[], + set_name=None, + expected_fields={"english": 1, "spanish": 2}, + timeout=None, + ), ), ], + indirect=["record"], ) -async def test_vector_get(session_vector_client, test_case, random_key): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +async def test_vector_get(session_vector_client, test_case, record): result = await session_vector_client.get( namespace=test_case.namespace, - key=random_key, + key=record, include_fields=test_case.include_fields, exclude_fields=test_case.exclude_fields, ) @@ -115,15 +139,10 @@ async def test_vector_get(session_vector_client, test_case, random_key): if test_case.set_name == None: test_case.set_name = "" assert result.key.set == test_case.set_name - assert result.key.key == random_key + assert result.key.key == record assert result.fields == test_case.expected_fields - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @@ -135,7 +154,6 @@ async def test_vector_get(session_vector_client, test_case, 
random_key): include_fields=["skills"], exclude_fields = None, set_name=None, - record_data=None, expected_fields=None, timeout=0.0001, ), diff --git a/tests/standard/aio/test_vector_client_insert.py b/tests/standard/aio/test_vector_client_insert.py index ec503c5b..c0a1645d 100644 --- a/tests/standard/aio/test_vector_client_insert.py +++ b/tests/standard/aio/test_vector_client_insert.py @@ -62,6 +62,10 @@ async def test_vector_insert_without_existing_record( record_data=test_case.record_data, set_name=test_case.set_name, ) + await session_vector_client.delete( + namespace=test_case.namespace, + key=random_key, + ) #@given(random_key=key_strategy()) @@ -78,12 +82,12 @@ async def test_vector_insert_without_existing_record( ], ) async def test_vector_insert_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): try: await session_vector_client.insert( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, ) @@ -93,14 +97,10 @@ async def test_vector_insert_with_existing_record( with pytest.raises(AVSServerError) as e_info: await session_vector_client.insert( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, ) - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) #@given(random_key=key_strategy()) @@ -130,4 +130,8 @@ async def test_vector_insert_timeout( set_name=test_case.set_name, timeout=test_case.timeout, ) + await session_vector_client.delete( + namespace=test_case.namespace, + key=random_key, + ) assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_is_indexed.py b/tests/standard/aio/test_vector_client_is_indexed.py index d1ae8519..7a7fae58 100644 --- a/tests/standard/aio/test_vector_client_is_indexed.py +++ 
b/tests/standard/aio/test_vector_client_is_indexed.py @@ -1,81 +1,36 @@ import pytest -import random +import time + +from utils import DEFAULT_NAMESPACE from aerospike_vector_search import AVSServerError -from aerospike_vector_search.aio import Client, AdminClient -import asyncio import grpc -# Define module-level constants for common arguments -NAMESPACE = "test" -SET_NAME = "isidxset" -INDEX_NAME = "isidx" -DIMENSIONS = 3 -VECTOR_FIELD = "vector" - -@pytest.fixture(scope="module", autouse=True) -async def setup_teardown(session_vector_client: Client, session_admin_client: AdminClient): - # Setup: Create index and upsert records - await session_admin_client.index_create( - namespace=NAMESPACE, - sets=SET_NAME, - name=INDEX_NAME, - dimensions=DIMENSIONS, - vector_field=VECTOR_FIELD, - ) - tasks = [] - for i in range(10): - tasks.append(session_vector_client.upsert( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i), - record_data={VECTOR_FIELD: [float(i)] * DIMENSIONS}, - )) - tasks.append(session_vector_client.wait_for_index_completion( - namespace=NAMESPACE, - name=INDEX_NAME, - )) - await asyncio.gather(*tasks) - yield - # Teardown: remove records - tasks = [] - for i in range(10): - tasks.append(session_vector_client.delete( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i) - )) - await asyncio.gather(*tasks) - - # Teardown: Drop index - await session_admin_client.index_drop( - namespace=NAMESPACE, - name=INDEX_NAME - ) async def test_vector_is_indexed( - session_vector_client, + session_vector_client, index, record ): + # give the record some time to be indexed + time.sleep(1) result = await session_vector_client.is_indexed( - namespace=NAMESPACE, - key="0", - index_name=INDEX_NAME, - set_name=SET_NAME, + namespace=DEFAULT_NAMESPACE, + key=record, + index_name=index, ) assert result is True async def test_vector_is_indexed_timeout( - session_vector_client, with_latency + session_vector_client, with_latency, index, record ): if not with_latency: 
pytest.skip("Server latency too low to test timeout") for _ in range(10): try: await session_vector_client.is_indexed( - namespace=NAMESPACE, - key="0", - index_name=INDEX_NAME, + namespace=DEFAULT_NAMESPACE, + key=record, + index_name=index, timeout=0.0001, ) except AVSServerError as se: diff --git a/tests/standard/aio/test_vector_client_search_by_key.py b/tests/standard/aio/test_vector_client_search_by_key.py index faabbec5..a2294ec2 100644 --- a/tests/standard/aio/test_vector_client_search_by_key.py +++ b/tests/standard/aio/test_vector_client_search_by_key.py @@ -1,14 +1,21 @@ import numpy as np -import asyncio import pytest + +from utils import DEFAULT_NAMESPACE +from .aio_utils import wait_for_index from aerospike_vector_search import types +INDEX = "sbk_index" +NAMESPACE = DEFAULT_NAMESPACE +DIMENSIONS = 3 +VEC_BIN = "vector" +SET_NAME = "test_set" + class vector_search_by_key_test_case: def __init__( self, *, - index_name, index_dimensions, vector_field, limit, @@ -18,10 +25,8 @@ def __init__( include_fields, exclude_fields, key_set, - record_data, expected_results, ): - self.index_name = index_name self.index_dimensions = index_dimensions self.vector_field = vector_field self.limit = limit @@ -30,45 +35,120 @@ def __init__( self.include_fields = include_fields self.exclude_fields = exclude_fields self.key_set = key_set - self.record_data = record_data self.expected_results = expected_results self.key_namespace = key_namespace -# TODO add a teardown + +@pytest.fixture(scope="module", autouse=True) +async def setup_index( + session_admin_client, +): + await session_admin_client.index_create( + namespace=DEFAULT_NAMESPACE, + name=INDEX, + vector_field=VEC_BIN, + dimensions=DIMENSIONS, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + 
healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) + ) + + yield + + await session_admin_client.index_drop( + namespace=DEFAULT_NAMESPACE, + name=INDEX, + ) + + +@pytest.fixture(scope="module", autouse=True) +async def setup_records( + session_vector_client, +): + recs = { + "rec1": { + "bin": 1, + VEC_BIN: [1.0] * DIMENSIONS, + }, + 2: { + "bin": 2, + VEC_BIN: [2.0] * DIMENSIONS, + }, + bytes("rec5", "utf-8"): { + "bin": 5, + VEC_BIN: [5.0] * DIMENSIONS, + }, + } + + keys = [] + for key, record in recs.items(): + await session_vector_client.upsert( + namespace=DEFAULT_NAMESPACE, + key=key, + record_data=record, + ) + keys.append(key) + + # write some records for set tests + set_recs = { + "srec100": { + "bin": 100, + VEC_BIN: [100.0] * DIMENSIONS, + }, + "srec101": { + "bin": 101, + VEC_BIN: [101.0] * DIMENSIONS, + }, + } + + for key, record in set_recs.items(): + await session_vector_client.upsert( + namespace=DEFAULT_NAMESPACE, + key=key, + record_data=record, + set_name=SET_NAME, + ) + keys.append(key) + + yield + + for key in keys: + await session_vector_client.delete( + namespace=DEFAULT_NAMESPACE, + key=key, + ) + + + #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( "test_case", [ # test string key vector_search_by_key_test_case( - index_name="basic_search", index_dimensions=3, vector_field="vector", limit=2, key="rec1", - key_namespace="test", - search_namespace="test", + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=None, exclude_fields=None, key_set=None, - record_data={ - "rec1": { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - "rec2": { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - "rec3": { - "bin": 3, - "vector": [3.0, 3.0, 3.0], - }, - }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", key="rec1", ), @@ -80,9 +160,9 @@ def __init__( ), types.Neighbor( 
key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key="rec2", + key=2, ), fields={ "bin": 2, @@ -94,202 +174,159 @@ def __init__( ), # test int key vector_search_by_key_test_case( - index_name="field_filter", index_dimensions=3, vector_field="vector", limit=3, - key=1, - key_namespace="test", - search_namespace="test", + key=2, + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=["bin"], exclude_fields=["bin"], key_set=None, - record_data={ - 1: { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - 2: { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=1, + key=2, ), fields={}, distance=0.0, ), types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=2, + key="rec1", ), fields={}, distance=3.0, ), + types.Neighbor( + key=types.Key( + namespace=DEFAULT_NAMESPACE, + set="", + key=bytes("rec5", "utf-8"), + ), + fields={}, + distance=27.0, + ), ], ), # test bytes key vector_search_by_key_test_case( - index_name="field_filter", index_dimensions=3, vector_field="vector", limit=3, - key=bytes("rec1", "utf-8"), - key_namespace="test", - search_namespace="test", + key=bytes("rec5", "utf-8"), + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=["bin"], exclude_fields=["bin"], key_set=None, - record_data={ - bytes("rec1", "utf-8"): { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - bytes("rec2", "utf-8"): { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=bytes("rec1", "utf-8"), + key=bytes("rec5", "utf-8"), ), fields={}, distance=0.0, ), types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=bytes("rec2", "utf-8"), + key=2, ), fields={}, - distance=3.0, + distance=27.0, + ), + 
types.Neighbor( + key=types.Key( + namespace=DEFAULT_NAMESPACE, + set="", + key="rec1", + ), + fields={}, + distance=48.0, ), ], ), - # test bytearray key - # TODO: add a bytearray key case, bytearrays are not hashable - # so this is not easily added. Leaving it for now. - # vector_search_by_key_test_case( - # index_name="field_filter", - # index_dimensions=3, - # vector_field="vector", - # limit=3, - # key=bytearray("rec1", "utf-8"), - # namespace="test", - # include_fields=["bin"], - # exclude_fields=["bin"], - # key_set=None, - # record_data={ - # bytearray("rec1", "utf-8"): { - # "bin": 1, - # "vector": [1.0, 1.0, 1.0], - # }, - # bytearray("rec1", "utf-8"): { - # "bin": 2, - # "vector": [2.0, 2.0, 2.0], - # }, - # }, - # expected_results=[ - # types.Neighbor( - # key=types.Key( - # namespace="test", - # set="", - # key=2, - # ), - # fields={}, - # distance=3.0, - # ), - # ], - # ), + # # test bytearray key + # # TODO: add a bytearray key case, bytearrays are not hashable + # # so this is not easily added. Leaving it for now. 
+ # # vector_search_by_key_test_case( + # # index_name="field_filter", + # # index_dimensions=3, + # # vector_field="vector", + # # limit=3, + # # key=bytearray("rec1", "utf-8"), + # # namespace=DEFAULT_NAMESPACE, + # # include_fields=["bin"], + # # exclude_fields=["bin"], + # # key_set=None, + # # record_data={ + # # bytearray("rec1", "utf-8"): { + # # "bin": 1, + # # "vector": [1.0, 1.0, 1.0], + # # }, + # # bytearray("rec1", "utf-8"): { + # # "bin": 2, + # # "vector": [2.0, 2.0, 2.0], + # # }, + # # }, + # # expected_results=[ + # # types.Neighbor( + # # key=types.Key( + # # namespace=DEFAULT_NAMESPACE, + # # set="", + # # key=2, + # # ), + # # fields={}, + # # distance=3.0, + # # ), + # # ], + # # ), # test with set name vector_search_by_key_test_case( - index_name="basic_search", index_dimensions=3, vector_field="vector", limit=2, - key="rec1", - key_namespace="test", - search_namespace="test", + key="srec100", + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=None, exclude_fields=None, - key_set="test_set", - record_data={ - "rec1": { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - "rec2": { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - }, + key_set=SET_NAME, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", - set="test_set", - key="rec1", + namespace=DEFAULT_NAMESPACE, + set=SET_NAME, + key="srec100", ), fields={ - "bin": 1, - "vector": [1.0, 1.0, 1.0], + "bin": 100, + "vector": [100.0] * DIMENSIONS, }, distance=0.0, ), types.Neighbor( key=types.Key( - namespace="test", - set="test_set", - key="rec2", + namespace=DEFAULT_NAMESPACE, + set=SET_NAME, + key="srec101", ), fields={ - "bin": 2, - "vector": [2.0, 2.0, 2.0], + "bin": 101, + "vector": [101.0] * DIMENSIONS, }, distance=3.0, ), ], ), - # test search key record and search records are in different namespaces - vector_search_by_key_test_case( - index_name="basic_search", - index_dimensions=3, - vector_field="vector", - limit=2, - key="rec1", - 
key_namespace="test", - search_namespace="index_storage", - include_fields=None, - exclude_fields=None, - key_set=None, - record_data={ - "rec1": { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - "rec2": { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - "rec3": { - "bin": 3, - "vector": [3.0, 3.0, 3.0], - }, - }, - expected_results=[], - ), ], ) async def test_vector_search_by_key( @@ -297,34 +334,11 @@ async def test_vector_search_by_key( session_admin_client, test_case, ): - - await session_admin_client.index_create( - namespace=test_case.search_namespace, - name=test_case.index_name, - vector_field=test_case.vector_field, - dimensions=test_case.index_dimensions, - ) - - tasks = [] - for key, rec in test_case.record_data.items(): - tasks.append(session_vector_client.upsert( - namespace=test_case.key_namespace, - key=key, - record_data=rec, - set_name=test_case.key_set, - )) - - tasks.append( - session_vector_client.wait_for_index_completion( - namespace=test_case.search_namespace, - name=test_case.index_name, - ) - ) - await asyncio.gather(*tasks) + await wait_for_index(session_admin_client, DEFAULT_NAMESPACE, INDEX) results = await session_vector_client.vector_search_by_key( search_namespace=test_case.search_namespace, - index_name=test_case.index_name, + index_name=INDEX, key=test_case.key, key_namespace=test_case.key_namespace, vector_field=test_case.vector_field, @@ -336,21 +350,6 @@ async def test_vector_search_by_key( assert results == test_case.expected_results - tasks = [] - for key in test_case.record_data: - tasks.append(session_vector_client.delete( - namespace=test_case.key_namespace, - set_name=test_case.key_set, - key=key, - )) - - await asyncio.gather(*tasks) - - await session_admin_client.index_drop( - namespace=test_case.search_namespace, - name=test_case.index_name, - ) - async def test_vector_search_by_key_different_namespaces( session_vector_client, @@ -382,10 +381,7 @@ async def test_vector_search_by_key_different_namespaces( }, ) - await 
session_vector_client.wait_for_index_completion( - namespace="index_storage", - name="diff_ns_idx", - ) + await wait_for_index(session_admin_client, "index_storage", "diff_ns_idx") results = await session_vector_client.vector_search_by_key( search_namespace="index_storage", diff --git a/tests/standard/aio/test_vector_client_update.py b/tests/standard/aio/test_vector_client_update.py index 6c31dcb1..4ec8bb55 100644 --- a/tests/standard/aio/test_vector_client_update.py +++ b/tests/standard/aio/test_vector_client_update.py @@ -1,6 +1,6 @@ import pytest from aerospike_vector_search import AVSServerError -from ...utils import random_key +from ...utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import grpc @@ -27,19 +27,19 @@ def __init__( "test_case", [ update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, ), update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [float(i) for i in range(1024)]}, set_name=None, timeout=None, ), update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, timeout=None, @@ -47,27 +47,14 @@ def __init__( ], ) async def test_vector_update_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): - try: - await session_vector_client.insert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - except Exception as e: - pass await session_vector_client.update( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, ) - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) #@given(random_key=key_strategy()) @@ -76,7 +63,7 @@ async def 
test_vector_update_with_existing_record( "test_case", [ update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -106,7 +93,7 @@ async def test_vector_update_without_existing_record( [ None, update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=0.0001, diff --git a/tests/standard/aio/test_vector_client_upsert.py b/tests/standard/aio/test_vector_client_upsert.py index 5cacc11d..cac115b4 100644 --- a/tests/standard/aio/test_vector_client_upsert.py +++ b/tests/standard/aio/test_vector_client_upsert.py @@ -1,5 +1,5 @@ import pytest -from ...utils import random_key +from ...utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import numpy as np @@ -24,19 +24,19 @@ def __init__(self, *, namespace, record_data, set_name, timeout, key=None): "test_case", [ upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, ), upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [float(i) for i in range(1024)]}, set_name=None, timeout=None, ), upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, timeout=None, @@ -65,7 +65,7 @@ async def test_vector_upsert_without_existing_record( "test_case", [ upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -92,14 +92,14 @@ async def test_vector_upsert_with_existing_record( "test_case", [ upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, key=np.int32(31), ), upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, 
record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -128,7 +128,7 @@ async def test_vector_upsert_with_numpy_key(session_vector_client, test_case): [ None, upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=0.0001, diff --git a/tests/standard/aio/test_vector_search.py b/tests/standard/aio/test_vector_search.py index 88cab272..ec690f74 100644 --- a/tests/standard/aio/test_vector_search.py +++ b/tests/standard/aio/test_vector_search.py @@ -1,7 +1,10 @@ import numpy as np import asyncio import pytest + from aerospike_vector_search import types +from utils import DEFAULT_NAMESPACE +from .aio_utils import wait_for_index class vector_search_test_case: @@ -32,6 +35,7 @@ def __init__( self.record_data = record_data self.expected_results = expected_results + # TODO add a teardown #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( @@ -43,7 +47,7 @@ def __init__( vector_field="vector", limit=3, query=[0.0, 0.0, 0.0], - namespace="test", + namespace=DEFAULT_NAMESPACE, include_fields=None, exclude_fields = None, set_name=None, @@ -56,7 +60,7 @@ def __init__( expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", key="rec1", ), @@ -74,7 +78,7 @@ def __init__( vector_field="vector", limit=3, query=[0.0, 0.0, 0.0], - namespace="test", + namespace=DEFAULT_NAMESPACE, include_fields=["bin1"], exclude_fields=["bin1"], set_name=None, @@ -87,7 +91,7 @@ def __init__( expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", key="rec1", ), @@ -109,6 +113,19 @@ async def test_vector_search( name=test_case.index_name, vector_field=test_case.vector_field, dimensions=test_case.index_dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we 
set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) tasks = [] @@ -121,9 +138,10 @@ async def test_vector_search( )) tasks.append( - session_vector_client.wait_for_index_completion( + wait_for_index( + session_admin_client, namespace=test_case.namespace, - name=test_case.index_name, + index=test_case.index_name, ) ) await asyncio.gather(*tasks) diff --git a/tests/utils.py b/tests/utils.py index 6d3d222c..ddcb1306 100755 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,13 @@ import pytest + +# default test values +DEFAULT_NAMESPACE = "test" +DEFAULT_INDEX_DIMENSION = 128 +DEFAULT_VECTOR_FIELD = "vector" + + def random_int(): return str(random.randint(0, 50_000)) From cbdced20dbd2ef009e54e49abcb966903429e756 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 20 Dec 2024 12:14:41 -0800 Subject: [PATCH 02/21] fix index create async tests --- tests/standard/aio/aio_utils.py | 5 +++ .../aio/test_admin_client_index_create.py | 38 +++++++++---------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/tests/standard/aio/aio_utils.py b/tests/standard/aio/aio_utils.py index c07fcf1f..867ec109 100644 --- a/tests/standard/aio/aio_utils.py +++ b/tests/standard/aio/aio_utils.py @@ -1,5 +1,10 @@ import asyncio + +async def drop_specified_index(admin_client, namespace, name): + admin_client.index_drop(namespace=namespace, name=name) + + def gen_records(count: int, vec_bin: str, vec_dim: int): num = 0 while num < count: diff --git a/tests/standard/aio/test_admin_client_index_create.py b/tests/standard/aio/test_admin_client_index_create.py index 251c9252..0897afca 100644 --- a/tests/standard/aio/test_admin_client_index_create.py +++ b/tests/standard/aio/test_admin_client_index_create.py @@ -2,7 +2,7 @@ from aerospike_vector_search import types, AVSServerError import grpc -from ...utils import 
random_name +from ...utils import random_name, DEFAULT_NAMESPACE from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity, Phase @@ -51,7 +51,7 @@ def __init__( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_1", dimensions=1024, vector_distance_metric=None, @@ -105,7 +105,7 @@ async def test_index_create(session_admin_client, test_case, random_name): "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_2", dimensions=495, vector_distance_metric=None, @@ -116,7 +116,7 @@ async def test_index_create(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_3", dimensions=2048, vector_distance_metric=None, @@ -174,7 +174,7 @@ async def test_index_create_with_dimnesions( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_4", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.COSINE, @@ -185,7 +185,7 @@ async def test_index_create_with_dimnesions( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_5", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.DOT_PRODUCT, @@ -196,7 +196,7 @@ async def test_index_create_with_dimnesions( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_6", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.MANHATTAN, @@ -207,7 +207,7 @@ async def test_index_create_with_dimnesions( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_7", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.HAMMING, @@ -262,7 +262,7 @@ async def test_index_create_with_vector_distance_metric( 
"test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_8", dimensions=1024, vector_distance_metric=None, @@ -273,7 +273,7 @@ async def test_index_create_with_vector_distance_metric( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_9", dimensions=1024, vector_distance_metric=None, @@ -326,7 +326,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_10", dimensions=1024, vector_distance_metric=None, @@ -342,7 +342,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_11", dimensions=1024, vector_distance_metric=None, @@ -358,7 +358,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_12", dimensions=1024, vector_distance_metric=None, @@ -372,7 +372,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_13", dimensions=1024, vector_distance_metric=None, @@ -386,7 +386,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_20", dimensions=1024, vector_distance_metric=None, @@ -473,7 +473,7 @@ async def test_index_create_with_index_params( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_14", dimensions=1024, vector_distance_metric=None, @@ -527,14 +527,14 @@ async def 
test_index_create_index_labels(session_admin_client, test_case, random "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_15", dimensions=1024, vector_distance_metric=None, sets=None, index_params=None, index_labels=None, - index_storage=types.IndexStorage(namespace="test", set_name="foo"), + index_storage=types.IndexStorage(namespace=DEFAULT_NAMESPACE, set_name="foo"), timeout=None, ), ], @@ -578,7 +578,7 @@ async def test_index_create_index_storage(session_admin_client, test_case, rando "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_16", dimensions=1024, vector_distance_metric=None, From 4d054b3ca4a6dea4c9fe07941ccea0aca955faaf Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 20 Dec 2024 12:52:19 -0800 Subject: [PATCH 03/21] speed up vector record indexing in test indexes --- tests/standard/aio/conftest.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/standard/aio/conftest.py b/tests/standard/aio/conftest.py index c619b802..354bba04 100644 --- a/tests/standard/aio/conftest.py +++ b/tests/standard/aio/conftest.py @@ -210,6 +210,19 @@ async def index(session_admin_client, index_name, request): namespace = namespace, vector_field = vector_field, dimensions = dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" 
+ ) + ) ) yield index_name try: From c7df29a249a6d9c5053e246201865b9b6601d81f Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 20 Dec 2024 14:06:29 -0800 Subject: [PATCH 04/21] refactor(tests): propogate record and index fixtures through sync tests --- src/aerospike_vector_search/aio/admin.py | 2 +- .../aio/test_admin_client_index_list.py | 5 - .../aio/test_admin_client_index_update.py | 6 +- .../standard/aio/test_vector_client_delete.py | 6 +- .../standard/aio/test_vector_client_exists.py | 5 - .../standard/aio/test_vector_client_insert.py | 22 +-- .../aio/test_vector_client_is_indexed.py | 15 +- .../standard/aio/test_vector_client_update.py | 1 - tests/standard/sync/conftest.py | 66 +++++++- tests/standard/sync/sync_utils.py | 22 +++ .../sync/test_admin_client_index_create.py | 56 +++---- .../sync/test_admin_client_index_drop.py | 40 +---- .../sync/test_admin_client_index_get.py | 77 +++------ .../test_admin_client_index_get_status.py | 22 +-- .../sync/test_admin_client_index_list.py | 27 +-- .../sync/test_admin_client_index_update.py | 38 +---- .../sync/test_vector_client_delete.py | 32 ++-- .../sync/test_vector_client_exists.py | 32 +--- tests/standard/sync/test_vector_client_get.py | 156 +++++++++--------- .../sync/test_vector_client_insert.py | 32 ++-- .../sync/test_vector_client_is_indexed.py | 74 +++------ .../sync/test_vector_client_update.py | 12 +- 22 files changed, 311 insertions(+), 437 deletions(-) diff --git a/src/aerospike_vector_search/aio/admin.py b/src/aerospike_vector_search/aio/admin.py index eae2a817..aa57bcd1 100644 --- a/src/aerospike_vector_search/aio/admin.py +++ b/src/aerospike_vector_search/aio/admin.py @@ -386,7 +386,7 @@ async def index_get_status( Note: This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, - the records may not immediately begin to merge into the index. 
To wait for all records to be merged into an index, use vector_client.wait_for_index_completion. + the records may not immediately begin to merge into the index. Warning: This API is subject to change. """ diff --git a/tests/standard/aio/test_admin_client_index_list.py b/tests/standard/aio/test_admin_client_index_list.py index 74c4bb9d..af633e8c 100644 --- a/tests/standard/aio/test_admin_client_index_list.py +++ b/tests/standard/aio/test_admin_client_index_list.py @@ -1,12 +1,7 @@ -import pytest - from aerospike_vector_search import AVSServerError import pytest import grpc - -from utils import DEFAULT_NAMESPACE, DEFAULT_VECTOR_FIELD, DEFAULT_INDEX_DIMENSION - from hypothesis import given, settings, Verbosity diff --git a/tests/standard/aio/test_admin_client_index_update.py b/tests/standard/aio/test_admin_client_index_update.py index be86a5dc..734b1b77 100644 --- a/tests/standard/aio/test_admin_client_index_update.py +++ b/tests/standard/aio/test_admin_client_index_update.py @@ -1,10 +1,10 @@ import time -import pytest -from aerospike_vector_search import types, AVSServerError -import grpc +from aerospike_vector_search import types from utils import DEFAULT_NAMESPACE +import pytest + server_defaults = { "m": 16, "ef_construction": 100, diff --git a/tests/standard/aio/test_vector_client_delete.py b/tests/standard/aio/test_vector_client_delete.py index 09bb935d..f5479095 100644 --- a/tests/standard/aio/test_vector_client_delete.py +++ b/tests/standard/aio/test_vector_client_delete.py @@ -1,6 +1,6 @@ import pytest from aerospike_vector_search import AVSServerError -from utils import DEFAULT_NAMESPACE +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import grpc @@ -62,11 +62,11 @@ async def test_vector_delete(session_vector_client, test_case, record): ], ) async def test_vector_delete_without_record( - session_vector_client, test_case, record + session_vector_client, test_case, random_key ): await 
session_vector_client.delete( namespace=test_case.namespace, - key=record, + key=random_key, ) diff --git a/tests/standard/aio/test_vector_client_exists.py b/tests/standard/aio/test_vector_client_exists.py index 3b2b5167..42cdf5c9 100644 --- a/tests/standard/aio/test_vector_client_exists.py +++ b/tests/standard/aio/test_vector_client_exists.py @@ -44,11 +44,6 @@ async def test_vector_exists(session_vector_client, test_case, record): ) assert result is True - await session_vector_client.delete( - namespace=test_case.namespace, - key=record, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) diff --git a/tests/standard/aio/test_vector_client_insert.py b/tests/standard/aio/test_vector_client_insert.py index c0a1645d..04e587cf 100644 --- a/tests/standard/aio/test_vector_client_insert.py +++ b/tests/standard/aio/test_vector_client_insert.py @@ -1,6 +1,6 @@ import pytest from aerospike_vector_search import AVSServerError -from ...utils import random_key +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import asyncio @@ -30,19 +30,19 @@ def __init__( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"homeSkills": [float(i) for i in range(1024)]}, set_name=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, timeout=None, @@ -74,7 +74,7 @@ async def test_vector_insert_without_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -84,16 +84,6 @@ async def test_vector_insert_without_existing_record( async def test_vector_insert_with_existing_record( 
session_vector_client, test_case, record ): - try: - await session_vector_client.insert( - namespace=test_case.namespace, - key=record, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - except Exception as e: - pass - with pytest.raises(AVSServerError) as e_info: await session_vector_client.insert( namespace=test_case.namespace, @@ -109,7 +99,7 @@ async def test_vector_insert_with_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=0.0001, diff --git a/tests/standard/aio/test_vector_client_is_indexed.py b/tests/standard/aio/test_vector_client_is_indexed.py index 7a7fae58..f222d6b3 100644 --- a/tests/standard/aio/test_vector_client_is_indexed.py +++ b/tests/standard/aio/test_vector_client_is_indexed.py @@ -2,16 +2,25 @@ import time from utils import DEFAULT_NAMESPACE +from .aio_utils import wait_for_index from aerospike_vector_search import AVSServerError import grpc async def test_vector_is_indexed( - session_vector_client, index, record + session_admin_client, + session_vector_client, + index, + record ): - # give the record some time to be indexed - time.sleep(1) + # wait for the record to be indexed + await wait_for_index( + admin_client=session_admin_client, + namespace=DEFAULT_NAMESPACE, + index=index + ) + result = await session_vector_client.is_indexed( namespace=DEFAULT_NAMESPACE, key=record, diff --git a/tests/standard/aio/test_vector_client_update.py b/tests/standard/aio/test_vector_client_update.py index 4ec8bb55..fc5b5264 100644 --- a/tests/standard/aio/test_vector_client_update.py +++ b/tests/standard/aio/test_vector_client_update.py @@ -91,7 +91,6 @@ async def test_vector_update_without_existing_record( @pytest.mark.parametrize( "test_case", [ - None, update_test_case( namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, diff --git a/tests/standard/sync/conftest.py 
b/tests/standard/sync/conftest.py index 6e7a229b..ed0667a9 100644 --- a/tests/standard/sync/conftest.py +++ b/tests/standard/sync/conftest.py @@ -4,7 +4,7 @@ from aerospike_vector_search import Client from aerospike_vector_search.admin import Client as AdminClient -from aerospike_vector_search import types +from aerospike_vector_search import types, AVSServerError from .sync_utils import gen_records @@ -195,14 +195,37 @@ def index_name(): @pytest.fixture(params=[DEFAULT_INDEX_ARGS]) def index(session_admin_client, index_name, request): - index_args = request.param - session_admin_client.index_create( + args = request.param + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + await session_admin_client.index_create( name = index_name, - **index_args, + namespace = namespace, + vector_field = vector_field, + dimensions = dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" 
+ ) + ) ) yield index_name - namespace = index_args.get("namespace", DEFAULT_NAMESPACE) - session_admin_client.index_drop(namespace=namespace, name=index_name) + try: + session_admin_client.index_drop(namespace=namespace, name=index_name) + except AVSServerError as se: + if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: + pass + else: + raise @pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) @@ -213,10 +236,35 @@ def records(session_vector_client, request): num_records = args.get("num_records", DEFAULT_NUM_RECORDS) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) keys = [] for key, rec in record_generator(count=num_records, vec_bin=vector_field, vec_dim=dimensions): - session_vector_client.upsert(namespace=namespace, key=key, record_data=rec) + session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) keys.append(key) - yield len(keys) + yield keys for key in keys: - session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file + session_vector_client.delete(key=key, namespace=namespace) + + +@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) +def record(session_vector_client, request): + args = request.param + record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) + key, rec = next(record_generator(count=1, vec_bin=vector_field, vec_dim=dimensions)) + session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) + yield key + session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file diff --git a/tests/standard/sync/sync_utils.py b/tests/standard/sync/sync_utils.py index 
b3936c8a..05aeeb94 100644 --- a/tests/standard/sync/sync_utils.py +++ b/tests/standard/sync/sync_utils.py @@ -1,6 +1,9 @@ +import time + def drop_specified_index(admin_client, namespace, name): admin_client.index_drop(namespace=namespace, name=name) + def gen_records(count: int, vec_bin: str, vec_dim: int): num = 0 while num < count: @@ -10,3 +13,22 @@ def gen_records(count: int, vec_bin: str, vec_dim: int): ) yield key_and_rec num += 1 + + +def wait_for_index(admin_client, namespace: str, index: str): + + verticies = 0 + unmerged_recs = 0 + + while verticies == 0 or unmerged_recs > 0: + status = admin_client.index_get_status( + namespace=namespace, + name=index, + ) + + verticies = status.index_healer_vertices_valid + unmerged_recs = status.unmerged_record_count + + # print(verticies) + # print(unmerged_recs) + time.sleep(0.5) \ No newline at end of file diff --git a/tests/standard/sync/test_admin_client_index_create.py b/tests/standard/sync/test_admin_client_index_create.py index dd23c75f..3b90fdba 100644 --- a/tests/standard/sync/test_admin_client_index_create.py +++ b/tests/standard/sync/test_admin_client_index_create.py @@ -2,9 +2,9 @@ import grpc from aerospike_vector_search import types, AVSServerError -from ...utils import random_name - +from ...utils import random_name, DEFAULT_NAMESPACE from .sync_utils import drop_specified_index + from hypothesis import given, settings, Verbosity server_defaults = { @@ -51,7 +51,7 @@ def __init__( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_1", dimensions=1024, vector_distance_metric=None, @@ -65,7 +65,7 @@ def __init__( ) def test_index_create(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -110,7 +110,7 @@ def 
test_index_create(session_admin_client, test_case, random_name): "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_2", dimensions=495, vector_distance_metric=None, @@ -121,7 +121,7 @@ def test_index_create(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_3", dimensions=2048, vector_distance_metric=None, @@ -136,7 +136,7 @@ def test_index_create(session_admin_client, test_case, random_name): def test_index_create_with_dimnesions(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -184,7 +184,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_4", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.COSINE, @@ -195,7 +195,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_5", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.DOT_PRODUCT, @@ -206,7 +206,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_6", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.MANHATTAN, @@ -217,7 +217,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, 
vector_field="example_7", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.HAMMING, @@ -234,7 +234,7 @@ def test_index_create_with_vector_distance_metric( ): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -278,7 +278,7 @@ def test_index_create_with_vector_distance_metric( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_8", dimensions=1024, vector_distance_metric=None, @@ -289,7 +289,7 @@ def test_index_create_with_vector_distance_metric( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_9", dimensions=1024, vector_distance_metric=None, @@ -304,7 +304,7 @@ def test_index_create_with_vector_distance_metric( def test_index_create_with_sets(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -348,7 +348,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_10", dimensions=1024, vector_distance_metric=None, @@ -364,7 +364,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_11", dimensions=1024, vector_distance_metric=None, @@ -377,7 +377,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, 
vector_field="example_20", dimensions=1024, vector_distance_metric=None, @@ -390,7 +390,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_12", dimensions=1024, vector_distance_metric=None, @@ -404,7 +404,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_13", dimensions=1024, vector_distance_metric=None, @@ -432,7 +432,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): ) def test_index_create_with_index_params(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -544,7 +544,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_16", dimensions=1024, vector_distance_metric=None, @@ -558,7 +558,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ ) def test_index_create_index_labels(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -605,21 +605,21 @@ def test_index_create_index_labels(session_admin_client, test_case, random_name) "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_17", dimensions=1024, vector_distance_metric=None, 
sets=None, index_params=None, index_labels=None, - index_storage=types.IndexStorage(namespace="test", set_name="foo"), + index_storage=types.IndexStorage(namespace=DEFAULT_NAMESPACE, set_name="foo"), timeout=None, ), ], ) def test_index_create_index_storage(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -662,7 +662,7 @@ def test_index_create_index_storage(session_admin_client, test_case, random_name "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_18", dimensions=1024, vector_distance_metric=None, @@ -681,7 +681,7 @@ def test_index_create_timeout( if not with_latency: pytest.skip("Server latency too low to test timeout") try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass diff --git a/tests/standard/sync/test_admin_client_index_drop.py b/tests/standard/sync/test_admin_client_index_drop.py index 0523510a..cf05a88f 100644 --- a/tests/standard/sync/test_admin_client_index_drop.py +++ b/tests/standard/sync/test_admin_client_index_drop.py @@ -3,8 +3,7 @@ from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_name - +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -12,52 +11,31 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_drop(session_admin_client, empty_test_case, random_name): - try: - - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - 
dimensions=1024, - ) - - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - - session_admin_client.index_drop(namespace="test", name=random_name) +def test_index_drop(session_admin_client, empty_test_case, index): + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) result = session_admin_client.index_list() result = result for index in result: - assert index["id"]["name"] != random_name + assert index["id"]["name"] != index @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_drop_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, + empty_test_case, + index, + with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - for i in range(10): try: session_admin_client.index_drop( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_admin_client_index_get.py b/tests/standard/sync/test_admin_client_index_get.py index 1831bcab..6cf0fd47 100644 --- a/tests/standard/sync/test_admin_client_index_get.py +++ b/tests/standard/sync/test_admin_client_index_get.py @@ -1,35 +1,20 @@ -import pytest -from ...utils import random_name - -from .sync_utils import drop_specified_index -from hypothesis import given, settings, Verbosity - +from ...utils import DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD from aerospike_vector_search import AVSServerError -import grpc +import grpc +from hypothesis import given, 
settings, Verbosity +import pytest @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get(session_admin_client, empty_test_case, random_name): - - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - - result = session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=True) - - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" +def test_index_get(session_admin_client, empty_test_case, index): + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) + + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD assert result["hnsw_params"]["m"] == 16 assert result["hnsw_params"]["ef_construction"] == 100 assert result["hnsw_params"]["ef"] == 100 @@ -37,8 +22,8 @@ def test_index_get(session_admin_client, empty_test_case, random_name): assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == "test" - assert result["storage"]["set_name"] == random_name + assert result["storage"]["namespace"] == DEFAULT_NAMESPACE + assert result["storage"]["set_name"] == index # Defaults assert result["sets"] == "" @@ -58,26 +43,18 @@ def test_index_get(session_admin_client, empty_test_case, random_name): # assert 
result["hnsw_params"]["merge_params"]["index_parallelism"] == 80 # assert result["hnsw_params"]["merge_params"]["reindex_parallelism"] == 26 - drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_no_defaults(session_admin_client, empty_test_case, random_name): - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +async def test_index_get_no_defaults(session_admin_client, empty_test_case, index): - result = session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=False) + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD # Defaults assert result["sets"] == "" @@ -105,32 +82,20 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["storage"].set_name == "" assert result["storage"]["set_name"] == "" - drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - 
dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se for i in range(10): try: result = session_admin_client.index_get( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/sync/test_admin_client_index_get_status.py b/tests/standard/sync/test_admin_client_index_get_status.py index 79a99dda..0f127829 100644 --- a/tests/standard/sync/test_admin_client_index_get_status.py +++ b/tests/standard/sync/test_admin_client_index_get_status.py @@ -1,7 +1,7 @@ import pytest import grpc -from ...utils import random_name +from ...utils import DEFAULT_NAMESPACE from .sync_utils import drop_specified_index from hypothesis import given, settings, Verbosity @@ -13,28 +13,18 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get_status(session_admin_client, empty_test_case, random_name): - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - result = session_admin_client.index_get_status(namespace="test", name=random_name) +def test_index_get_status(session_admin_client, empty_test_case, index): + result = session_admin_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) assert result.unmerged_record_count == 0 - drop_specified_index(session_admin_client, "test", random_name) + drop_specified_index(session_admin_client, DEFAULT_NAMESPACE, index) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_status_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, 
empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -42,7 +32,7 @@ def test_index_get_status_timeout( for i in range(10): try: result = session_admin_client.index_get_status( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_admin_client_index_list.py b/tests/standard/sync/test_admin_client_index_list.py index b2a2c9e7..57cad194 100644 --- a/tests/standard/sync/test_admin_client_index_list.py +++ b/tests/standard/sync/test_admin_client_index_list.py @@ -1,24 +1,15 @@ from aerospike_vector_search import AVSServerError +from .sync_utils import drop_specified_index import pytest import grpc - -from ...utils import random_name - -from .sync_utils import drop_specified_index from hypothesis import given, settings, Verbosity @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_list(session_admin_client, empty_test_case, random_name): - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +def test_index_list(session_admin_client, empty_test_case, index): result = session_admin_client.index_list(apply_defaults=True) assert len(result) > 0 for index in result: @@ -35,28 +26,18 @@ def test_index_list(session_admin_client, empty_test_case, random_name): assert isinstance(index["hnsw_params"]["batching_params"]["reindex_interval"], int) assert isinstance(index["storage"]["namespace"], str) assert isinstance(index["storage"]["set_name"], str) - drop_specified_index(session_admin_client, "test", random_name) + drop_specified_index(session_admin_client, "test", index) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) 
#@settings(max_examples=1, deadline=1000) def test_index_list_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se for i in range(10): diff --git a/tests/standard/sync/test_admin_client_index_update.py b/tests/standard/sync/test_admin_client_index_update.py index 77ee984d..f682ff71 100644 --- a/tests/standard/sync/test_admin_client_index_update.py +++ b/tests/standard/sync/test_admin_client_index_update.py @@ -1,16 +1,14 @@ import time -import pytest -from aerospike_vector_search import types, AVSServerError -import grpc -from .sync_utils import drop_specified_index +from aerospike_vector_search import types +from utils import DEFAULT_NAMESPACE +import pytest class index_update_test_case: def __init__( self, *, - namespace, vector_field, dimensions, initial_labels, @@ -18,7 +16,6 @@ def __init__( hnsw_index_update, timeout ): - self.namespace = namespace self.vector_field = vector_field self.dimensions = dimensions self.initial_labels = initial_labels @@ -31,7 +28,6 @@ def __init__( "test_case", [ index_update_test_case( - namespace="test", vector_field="update_2", dimensions=256, initial_labels={"status": "active"}, @@ -48,30 +44,11 @@ def __init__( ), ], ) -def test_index_update(session_admin_client, test_case): - trimmed_random = "saUEN1-" - - # Drop any pre-existing index with the same name - try: - session_admin_client.index_drop(namespace="test", name=trimmed_random) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: - pass - - # Create the index - session_admin_client.index_create( - namespace=test_case.namespace, - name=trimmed_random, - 
vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - index_labels=test_case.initial_labels, - timeout=test_case.timeout, - ) - +def test_index_update(session_admin_client, test_case, index): # Update the index with parameters based on the test case session_admin_client.index_update( namespace=test_case.namespace, - name=trimmed_random, + name=index, index_labels=test_case.update_labels, hnsw_update_params=test_case.hnsw_index_update, timeout=100_000, @@ -80,7 +57,7 @@ def test_index_update(session_admin_client, test_case): time.sleep(10) # Verify the update - result = session_admin_client.index_get(namespace=test_case.namespace, name=trimmed_random, apply_defaults=True) + result = session_admin_client.index_get(namespace=test_case.namespace, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." # Assertions @@ -104,6 +81,3 @@ def test_index_update(session_admin_client, test_case): assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == test_case.hnsw_index_update.healer_params.max_scan_rate_per_node assert result["hnsw_params"]["enable_vector_integrity_check"] == test_case.hnsw_index_update.enable_vector_integrity_check - - # Clean up by dropping the index after the test - drop_specified_index(session_admin_client, test_case.namespace, trimmed_random) diff --git a/tests/standard/sync/test_vector_client_delete.py b/tests/standard/sync/test_vector_client_delete.py index 60938a09..50f203d2 100644 --- a/tests/standard/sync/test_vector_client_delete.py +++ b/tests/standard/sync/test_vector_client_delete.py @@ -3,7 +3,7 @@ import grpc from aerospike_vector_search import AVSServerError -from ...utils import random_key +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -13,13 +13,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = 
record_data self.timeout = timeout @@ -29,33 +27,25 @@ def __init__( "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -def test_vector_delete(session_vector_client, test_case, random_key): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +def test_vector_delete(session_vector_client, test_case, record): session_vector_client.delete( namespace=test_case.namespace, - key=random_key, + key=record, ) with pytest.raises(AVSServerError) as e_info: result = session_vector_client.get( - namespace=test_case.namespace, key=random_key + namespace=test_case.namespace, key=record ) @@ -65,9 +55,8 @@ def test_vector_delete(session_vector_client, test_case, random_key): "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), ], @@ -85,15 +74,14 @@ def test_vector_delete_without_record(session_vector_client, test_case, random_k "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=0.0001, ), ], ) def test_vector_delete_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -101,7 +89,7 @@ def test_vector_delete_timeout( for i in range(10): try: session_vector_client.delete( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) except 
AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_vector_client_exists.py b/tests/standard/sync/test_vector_client_exists.py index fbacf1ca..5be07316 100644 --- a/tests/standard/sync/test_vector_client_exists.py +++ b/tests/standard/sync/test_vector_client_exists.py @@ -1,7 +1,7 @@ import pytest import grpc -from ...utils import random_key +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity from aerospike_vector_search import types, AVSServerError @@ -12,13 +12,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = record_data self.timeout = timeout @@ -28,38 +26,24 @@ def __init__( "test_case", [ exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -def test_vector_exists(session_vector_client, test_case, random_key): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - timeout=None, - ) +def test_vector_exists(session_vector_client, test_case, record): result = session_vector_client.exists( namespace=test_case.namespace, - key=random_key, + key=record, ) assert result is True - session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @@ -67,12 +51,12 @@ def test_vector_exists(session_vector_client, test_case, random_key): "test_case", [ exists_test_case( - namespace="test", set_name=None, record_data=None, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, set_name=None, timeout=0.0001 ), ], ) def 
test_vector_exists_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -80,7 +64,7 @@ def test_vector_exists_timeout( for i in range(10): try: result = session_vector_client.exists( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_vector_client_get.py b/tests/standard/sync/test_vector_client_get.py index ab26c1f1..3dd85e67 100644 --- a/tests/standard/sync/test_vector_client_get.py +++ b/tests/standard/sync/test_vector_client_get.py @@ -1,11 +1,10 @@ +from aerospike_vector_search import types, AVSServerError +from utils import DEFAULT_NAMESPACE, random_key + import pytest import grpc -from ...utils import random_key - from hypothesis import given, settings, Verbosity -from aerospike_vector_search import types, AVSServerError - class get_test_case: def __init__( @@ -15,7 +14,6 @@ def __init__( include_fields, exclude_fields, set_name, - record_data, expected_fields, timeout, ): @@ -23,7 +21,6 @@ def __init__( self.include_fields = include_fields self.exclude_fields = exclude_fields self.set_name = set_name - self.record_data = record_data self.expected_fields = expected_fields self.timeout = timeout @@ -31,80 +28,89 @@ def __init__( #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( - "test_case", + "record,test_case", [ - get_test_case( - namespace="test", - include_fields=["skills"], - exclude_fields=None, - set_name=None, - record_data={"skills": [i for i in range(1024)]}, - expected_fields={"skills": [i for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"skills": 1024}))}, 
+ get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["skills"], + exclude_fields=None, + set_name=None, + expected_fields={"skills": 1024}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - exclude_fields=None, - set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, - expected_fields={"english": [float(i) for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": [float(i) for i in range(1024)]}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields=None, + set_name=None, + expected_fields={"english": [float(i) for i in range(1024)]}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - exclude_fields = None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields=None, + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=["spanish"], + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["spanish"], - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", 
{"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["spanish"], + exclude_fields=["spanish"], + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=[], - exclude_fields=None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=[], + exclude_fields=None, + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=[], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1, "spanish": 2}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=[], + set_name=None, + expected_fields={"english": 1, "spanish": 2}, + timeout=None, + ), ), ], + indirect=["record"], ) -def test_vector_get(session_vector_client, test_case, random_key): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +def test_vector_get(session_vector_client, test_case, random_key, record_data): result = session_vector_client.get( namespace=test_case.namespace, key=random_key, @@ -112,18 +118,13 @@ def test_vector_get(session_vector_client, test_case, random_key): exclude_fields=test_case.exclude_fields, ) assert result.key.namespace == test_case.namespace - if test_case.set_name == None: + if test_case.set_name is None: test_case.set_name = "" assert result.key.set == test_case.set_name assert result.key.key == random_key assert result.fields == test_case.expected_fields - session_vector_client.delete( - 
namespace=test_case.namespace, - key=random_key, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @@ -131,11 +132,10 @@ def test_vector_get(session_vector_client, test_case, random_key): "test_case", [ get_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, include_fields=["skills"], exclude_fields=None, set_name=None, - record_data=None, expected_fields=None, timeout=0.0001, ), diff --git a/tests/standard/sync/test_vector_client_insert.py b/tests/standard/sync/test_vector_client_insert.py index 5c1a39e0..b763b9fa 100644 --- a/tests/standard/sync/test_vector_client_insert.py +++ b/tests/standard/sync/test_vector_client_insert.py @@ -2,7 +2,7 @@ from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_key +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -26,21 +26,21 @@ def __init__( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"homeSkills": [float(i) for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, @@ -75,7 +75,7 @@ def test_vector_insert_without_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, @@ -84,25 +84,15 @@ def test_vector_insert_without_existing_record( ], ) def test_vector_insert_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): - session_vector_client.upsert( - namespace=test_case.namespace, - 
key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) with pytest.raises(AVSServerError) as e_info: session_vector_client.insert( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, ) - session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) #@given(random_key=key_strategy()) @@ -111,7 +101,7 @@ def test_vector_insert_with_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, ignore_mem_queue_full=True, @@ -147,7 +137,7 @@ def test_vector_insert_without_existing_record_ignore_mem_queue_full( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, @@ -170,6 +160,10 @@ def test_vector_insert_timeout( set_name=test_case.set_name, timeout=test_case.timeout, ) + session_vector_client.delete( + namespace=test_case.namespace, + key=random_key, + ) except AVSServerError as e: if e.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: assert e.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/sync/test_vector_client_is_indexed.py b/tests/standard/sync/test_vector_client_is_indexed.py index 4f761957..d1a61c2e 100644 --- a/tests/standard/sync/test_vector_client_is_indexed.py +++ b/tests/standard/sync/test_vector_client_is_indexed.py @@ -1,76 +1,46 @@ import pytest -import random + from aerospike_vector_search import AVSServerError -from aerospike_vector_search import Client, AdminClient +from utils import DEFAULT_NAMESPACE +from .sync_utils import wait_for_index import grpc -# Define module-level constants for common arguments -NAMESPACE = "test" -SET_NAME = "isidxset" -INDEX_NAME = "isidx" -DIMENSIONS = 3 -VECTOR_FIELD = "vector" -@pytest.fixture(scope="module", 
autouse=True) -async def setup_teardown(session_vector_client: Client, session_admin_client: AdminClient): - # Setup: Create index and upsert records - session_admin_client.index_create( - namespace=NAMESPACE, - sets=SET_NAME, - name=INDEX_NAME, - dimensions=DIMENSIONS, - vector_field=VECTOR_FIELD, - ) - for i in range(10): - session_vector_client.upsert( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i), - record_data={VECTOR_FIELD: [float(i)] * DIMENSIONS}, - ) - session_vector_client.wait_for_index_completion( - namespace=NAMESPACE, - name=INDEX_NAME, - ) - yield - # Teardown: remove records - for i in range(10): - session_vector_client.delete( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i) - ) - - # Teardown: Drop index - session_admin_client.index_drop( - namespace=NAMESPACE, - name=INDEX_NAME - ) - -async def test_vector_is_indexed( +def test_vector_is_indexed( + session_admin_client, session_vector_client, + index, + record, ): + # wait for the record to be indexed + wait_for_index( + admin_client=session_admin_client, + namespace=DEFAULT_NAMESPACE, + index=index + ) + result = session_vector_client.is_indexed( - namespace=NAMESPACE, + namespace=DEFAULT_NAMESPACE, key="0", - index_name=INDEX_NAME, - set_name=SET_NAME, + index_name=index, ) assert result is True -async def test_vector_is_indexed_timeout( - session_vector_client, with_latency +def test_vector_is_indexed_timeout( + session_vector_client, + with_latency, + random_name, ): if not with_latency: pytest.skip("Server latency too low to test timeout") for _ in range(10): try: session_vector_client.is_indexed( - namespace=NAMESPACE, + namespace=DEFAULT_NAMESPACE, key="0", - index_name=INDEX_NAME, + index_name=random_name, timeout=0.0001, ) except AVSServerError as se: diff --git a/tests/standard/sync/test_vector_client_update.py b/tests/standard/sync/test_vector_client_update.py index ef111b49..c11fe643 100644 --- a/tests/standard/sync/test_vector_client_update.py +++ 
b/tests/standard/sync/test_vector_client_update.py @@ -41,19 +41,11 @@ def __init__(self, *, namespace, record_data, set_name, timeout): ], ) def test_vector_update_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - timeout=None, - ) - session_vector_client.update( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, timeout=None, From 80a32a79acd20d6a8cbb5a6b7119527544f68b0a Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 20 Dec 2024 16:18:53 -0800 Subject: [PATCH 05/21] fix async tests --- tests/standard/aio/aio_utils.py | 2 +- .../aio/test_admin_client_index_get.py | 10 +- tests/standard/aio/test_service_config.py | 992 +++++++++--------- tests/standard/aio/test_vector_client_get.py | 17 - .../aio/test_vector_client_search_by_key.py | 20 +- tests/standard/aio/test_vector_search.py | 10 +- .../sync/test_admin_client_index_get.py | 10 +- 7 files changed, 531 insertions(+), 530 deletions(-) diff --git a/tests/standard/aio/aio_utils.py b/tests/standard/aio/aio_utils.py index 867ec109..2f39646b 100644 --- a/tests/standard/aio/aio_utils.py +++ b/tests/standard/aio/aio_utils.py @@ -2,7 +2,7 @@ async def drop_specified_index(admin_client, namespace, name): - admin_client.index_drop(namespace=namespace, name=name) + await admin_client.index_drop(namespace=namespace, name=name) def gen_records(count: int, vec_bin: str, vec_dim: int): diff --git a/tests/standard/aio/test_admin_client_index_get.py b/tests/standard/aio/test_admin_client_index_get.py index b584da8a..ff21b3e9 100644 --- a/tests/standard/aio/test_admin_client_index_get.py +++ b/tests/standard/aio/test_admin_client_index_get.py @@ -21,7 +21,7 @@ async def test_index_get(session_admin_client, empty_test_case, index): assert 
result["hnsw_params"]["ef_construction"] == 100 assert result["hnsw_params"]["ef"] == 100 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 assert result["storage"]["namespace"] == DEFAULT_NAMESPACE @@ -39,7 +39,7 @@ async def test_index_get(session_admin_client, empty_test_case, index): assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 1000 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 10000 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 10.0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "0 0/15 * ? * * *" + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" 
assert result["hnsw_params"]["healer_params"]["parallelism"] == 1 # index parallelism and reindex parallelism are dynamic depending on the CPU cores of the host @@ -66,7 +66,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, inde assert result["hnsw_params"]["ef"] == 0 assert result["hnsw_params"]["ef_construction"] == 0 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 0 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 0 + # This is set by default to 10000 in the index fixture + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["max_mem_queue_size"] == 0 assert result["hnsw_params"]["index_caching_params"]["max_entries"] == 0 assert result["hnsw_params"]["index_caching_params"]["expiry"] == 0 @@ -74,7 +75,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, inde assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 0 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 0 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "" + # This is set by default to * * * * * ? in the index fixture + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" 
assert result["hnsw_params"]["healer_params"]["parallelism"] == 0 assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 0 diff --git a/tests/standard/aio/test_service_config.py b/tests/standard/aio/test_service_config.py index 66277e57..bcc94eda 100644 --- a/tests/standard/aio/test_service_config.py +++ b/tests/standard/aio/test_service_config.py @@ -1,496 +1,496 @@ -import pytest -import time - -import os -import json - -from aerospike_vector_search import AVSServerError, types -from aerospike_vector_search.aio import AdminClient - - -class service_config_parse_test_case: - def __init__(self, *, service_config_path): - self.service_config_path = service_config_path - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_parse_test_case( - service_config_path="service_configs/master.json" - ), - ], -) -async def test_admin_client_service_config_parse( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - pass - - -class service_config_test_case: - def __init__( - self, *, service_config_path, namespace, name, vector_field, dimensions - ): - - script_dir = os.path.dirname(os.path.abspath(__file__)) - - self.service_config_path = os.path.abspath( - os.path.join(script_dir, "..", "..", service_config_path) - ) - - with open(self.service_config_path, 
"rb") as f: - self.service_config = json.load(f) - - self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ - "maxAttempts" - ] - self.initial_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] - ) - self.max_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] - ) - self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ - "backoffMultiplier" - ] - self.retryable_status_codes = self.service_config["methodConfig"][0][ - "retryPolicy" - ]["retryableStatusCodes"] - self.namespace = namespace - self.name = name - self.vector_field = vector_field - self.dimensions = dimensions - - -def calculate_expected_time( - max_attempts, - initial_backoff, - backoff_multiplier, - max_backoff, - retryable_status_codes, -): - - current_backkoff = initial_backoff - - expected_time = 0 - for attempt in range(max_attempts - 1): - expected_time += current_backkoff - current_backkoff *= backoff_multiplier - current_backkoff = min(current_backkoff, max_backoff) - - return expected_time - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retries.json", - namespace="test", - name="service_config_index_1", - vector_field="example_1", - dimensions=1024, - ) - ], -) -async def test_admin_client_service_config_retries( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - 
certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - try: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/initial_backoff.json", - namespace="test", - name="service_config_index_2", - vector_field="example_1", - dimensions=1024, - ) - ], -) -async def test_admin_client_service_config_initial_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - 
try: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/max_backoff.json", - namespace="test", - name="service_config_index_3", - vector_field="example_1", - dimensions=1024, - ), - service_config_test_case( - service_config_path="service_configs/max_backoff_lower_than_initial.json", - namespace="test", - name="service_config_index_4", - vector_field="example_1", - dimensions=1024, - ), - ], -) -async def test_admin_client_service_config_max_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - 
) as client: - - try: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/backoff_multiplier.json", - namespace="test", - name="service_config_index_5", - vector_field="example_1", - dimensions=1024, - ) - ], -) -async def test_admin_client_service_config_backoff_multiplier( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, 
- ) - except: - pass - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retryable_status_codes.json", - namespace="test", - name="service_config_index_6", - vector_field=None, - dimensions=None, - ) - ], -) -async def test_admin_client_service_config_retryable_status_codes( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_get_status( - namespace=test_case.namespace, - 
name=test_case.name, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 2 +# import pytest +# import time + +# import os +# import json + +# from aerospike_vector_search import AVSServerError, types +# from aerospike_vector_search.aio import AdminClient + + +# class service_config_parse_test_case: +# def __init__(self, *, service_config_path): +# self.service_config_path = service_config_path + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_parse_test_case( +# service_config_path="service_configs/master.json" +# ), +# ], +# ) +# async def test_admin_client_service_config_parse( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: +# pass + + +# class service_config_test_case: +# def __init__( +# self, *, service_config_path, namespace, name, vector_field, dimensions +# ): + +# script_dir = os.path.dirname(os.path.abspath(__file__)) + +# self.service_config_path = os.path.abspath( +# os.path.join(script_dir, "..", "..", service_config_path) +# ) + +# with open(self.service_config_path, "rb") as f: +# self.service_config = json.load(f) + +# self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ +# "maxAttempts" +# ] +# self.initial_backoff = 
int( +# self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] +# ) +# self.max_backoff = int( +# self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] +# ) +# self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ +# "backoffMultiplier" +# ] +# self.retryable_status_codes = self.service_config["methodConfig"][0][ +# "retryPolicy" +# ]["retryableStatusCodes"] +# self.namespace = namespace +# self.name = name +# self.vector_field = vector_field +# self.dimensions = dimensions + + +# def calculate_expected_time( +# max_attempts, +# initial_backoff, +# backoff_multiplier, +# max_backoff, +# retryable_status_codes, +# ): + +# current_backkoff = initial_backoff + +# expected_time = 0 +# for attempt in range(max_attempts - 1): +# expected_time += current_backkoff +# current_backkoff *= backoff_multiplier +# current_backkoff = min(current_backkoff, max_backoff) + +# return expected_time + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retries.json", +# namespace="test", +# name="service_config_index_1", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# async def test_admin_client_service_config_retries( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# 
service_config_path=test_case.service_config_path, +# ) as client: +# try: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/initial_backoff.json", +# namespace="test", +# name="service_config_index_2", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# async def test_admin_client_service_config_initial_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# await client.index_create( +# 
namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/max_backoff.json", +# namespace="test", +# name="service_config_index_3", +# vector_field="example_1", +# dimensions=1024, +# ), +# service_config_test_case( +# service_config_path="service_configs/max_backoff_lower_than_initial.json", +# namespace="test", +# name="service_config_index_4", +# vector_field="example_1", +# dimensions=1024, +# ), +# ], +# ) +# async def test_admin_client_service_config_max_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# 
service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/backoff_multiplier.json", +# namespace="test", +# name="service_config_index_5", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# async def test_admin_client_service_config_backoff_multiplier( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: + +# await client.index_create( 
+# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retryable_status_codes.json", +# namespace="test", +# name="service_config_index_6", +# vector_field=None, +# dimensions=None, +# ) +# ], +# ) +# async def test_admin_client_service_config_retryable_status_codes( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# 
test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_get_status( +# namespace=test_case.namespace, +# name=test_case.name, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 2 diff --git a/tests/standard/aio/test_vector_client_get.py b/tests/standard/aio/test_vector_client_get.py index fa20568b..305ca760 100644 --- a/tests/standard/aio/test_vector_client_get.py +++ b/tests/standard/aio/test_vector_client_get.py @@ -7,23 +7,6 @@ from hypothesis import given, settings, Verbosity -# gen_records is used with the records test fixture -def gen_record(count, vec_bin, vec_dim): - num = 0 - while num < count: - key_and_rec = ( - num, - { - "bin1": num, - "bin2": num, - "bin3": num, - vec_bin: [float(num)] * vec_dim - } - ) - yield key_and_rec - num += 1 - - class get_test_case: def __init__( self, diff --git a/tests/standard/aio/test_vector_client_search_by_key.py b/tests/standard/aio/test_vector_client_search_by_key.py index a2294ec2..9434e1d1 100644 --- a/tests/standard/aio/test_vector_client_search_by_key.py +++ b/tests/standard/aio/test_vector_client_search_by_key.py @@ -39,7 +39,7 @@ def __init__( self.key_namespace = key_namespace -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(scope="module") async def setup_index( session_admin_client, ): @@ -92,12 +92,12 @@ async def setup_records( keys = [] for key, record in recs.items(): + keys.append(key) await session_vector_client.upsert( namespace=DEFAULT_NAMESPACE, key=key, record_data=record, ) - keys.append(key) # write some records for set tests set_recs = { @@ -112,13 +112,13 @@ async def setup_records( } for key, record in set_recs.items(): + keys.append(key) await session_vector_client.upsert( namespace=DEFAULT_NAMESPACE, key=key, record_data=record, set_name=SET_NAME, ) - keys.append(key) yield @@ -332,6 +332,7 @@ async def 
setup_records( async def test_vector_search_by_key( session_vector_client, session_admin_client, + setup_index, test_case, ): await wait_for_index(session_admin_client, DEFAULT_NAMESPACE, INDEX) @@ -361,6 +362,19 @@ async def test_vector_search_by_key_different_namespaces( name="diff_ns_idx", vector_field="vec", dimensions=3, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) await session_vector_client.upsert( diff --git a/tests/standard/aio/test_vector_search.py b/tests/standard/aio/test_vector_search.py index ec690f74..ff597fe0 100644 --- a/tests/standard/aio/test_vector_search.py +++ b/tests/standard/aio/test_vector_search.py @@ -44,7 +44,7 @@ def __init__( vector_search_test_case( index_name="basic_search", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], namespace=DEFAULT_NAMESPACE, @@ -54,7 +54,7 @@ def __init__( record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ @@ -66,7 +66,7 @@ def __init__( ), fields={ "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, distance=3.0, ), @@ -75,7 +75,7 @@ def __init__( vector_search_test_case( index_name="field_filter", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], namespace=DEFAULT_NAMESPACE, @@ -85,7 +85,7 @@ def __init__( record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ diff --git a/tests/standard/sync/test_admin_client_index_get.py b/tests/standard/sync/test_admin_client_index_get.py index 6cf0fd47..b8034a71 100644 --- 
a/tests/standard/sync/test_admin_client_index_get.py +++ b/tests/standard/sync/test_admin_client_index_get.py @@ -19,7 +19,7 @@ def test_index_get(session_admin_client, empty_test_case, index): assert result["hnsw_params"]["ef_construction"] == 100 assert result["hnsw_params"]["ef"] == 100 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 assert result["storage"]["namespace"] == DEFAULT_NAMESPACE @@ -36,7 +36,7 @@ def test_index_get(session_admin_client, empty_test_case, index): assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 1000 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 10000 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 10.0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "0 0/15 * ? * * *" + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" 
assert result["hnsw_params"]["healer_params"]["parallelism"] == 1 # index parallelism and reindex parallelism are dynamic depending on the CPU cores of the host @@ -64,7 +64,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, inde assert result["hnsw_params"]["ef"] == 0 assert result["hnsw_params"]["ef_construction"] == 0 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 0 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 0 + # This is set by default to 10000 in the index fixture + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["max_mem_queue_size"] == 0 assert result["hnsw_params"]["index_caching_params"]["max_entries"] == 0 assert result["hnsw_params"]["index_caching_params"]["expiry"] == 0 @@ -72,7 +73,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, inde assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 0 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 0 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "" + # This is set by default to * * * * * ? in the index fixture + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" 
assert result["hnsw_params"]["healer_params"]["parallelism"] == 0 assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 0 From bc1388b524cc5c0b139f56f57f9c3030a24570f8 Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 23 Dec 2024 15:00:04 -0800 Subject: [PATCH 06/21] fix sync tests --- tests/standard/sync/conftest.py | 3 +- .../sync/test_admin_client_index_list.py | 1 - .../sync/test_admin_client_index_update.py | 6 +- tests/standard/sync/test_service_config.py | 990 +++++++++--------- tests/standard/sync/test_vector_client_get.py | 6 +- .../sync/test_vector_client_is_indexed.py | 4 +- .../sync/test_vector_client_search_by_key.py | 13 + tests/standard/sync/test_vector_search.py | 10 +- 8 files changed, 524 insertions(+), 509 deletions(-) diff --git a/tests/standard/sync/conftest.py b/tests/standard/sync/conftest.py index ed0667a9..d6ffff5b 100644 --- a/tests/standard/sync/conftest.py +++ b/tests/standard/sync/conftest.py @@ -7,6 +7,7 @@ from aerospike_vector_search import types, AVSServerError from .sync_utils import gen_records +import grpc #import logging #logger = logging.getLogger(__name__) @@ -199,7 +200,7 @@ def index(session_admin_client, index_name, request): namespace = args.get("namespace", DEFAULT_NAMESPACE) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - await session_admin_client.index_create( + session_admin_client.index_create( name = index_name, namespace = namespace, vector_field = vector_field, diff --git a/tests/standard/sync/test_admin_client_index_list.py b/tests/standard/sync/test_admin_client_index_list.py index 57cad194..4a119573 100644 --- a/tests/standard/sync/test_admin_client_index_list.py +++ b/tests/standard/sync/test_admin_client_index_list.py @@ -26,7 +26,6 @@ def test_index_list(session_admin_client, empty_test_case, index): assert isinstance(index["hnsw_params"]["batching_params"]["reindex_interval"], int) assert 
isinstance(index["storage"]["namespace"], str) assert isinstance(index["storage"]["set_name"], str) - drop_specified_index(session_admin_client, "test", index) @pytest.mark.parametrize("empty_test_case", [None]) diff --git a/tests/standard/sync/test_admin_client_index_update.py b/tests/standard/sync/test_admin_client_index_update.py index f682ff71..e5be9822 100644 --- a/tests/standard/sync/test_admin_client_index_update.py +++ b/tests/standard/sync/test_admin_client_index_update.py @@ -47,7 +47,7 @@ def __init__( def test_index_update(session_admin_client, test_case, index): # Update the index with parameters based on the test case session_admin_client.index_update( - namespace=test_case.namespace, + namespace=DEFAULT_NAMESPACE, name=index, index_labels=test_case.update_labels, hnsw_update_params=test_case.hnsw_index_update, @@ -57,9 +57,11 @@ def test_index_update(session_admin_client, test_case, index): time.sleep(10) # Verify the update - result = session_admin_client.index_get(namespace=test_case.namespace, name=index, apply_defaults=True) + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." 
+ assert result["id"]["namespace"] == DEFAULT_NAMESPACE + # Assertions if test_case.hnsw_index_update.batching_params: assert result["hnsw_params"]["batching_params"]["max_index_records"] == test_case.hnsw_index_update.batching_params.max_index_records diff --git a/tests/standard/sync/test_service_config.py b/tests/standard/sync/test_service_config.py index 3083a74e..b22f1088 100644 --- a/tests/standard/sync/test_service_config.py +++ b/tests/standard/sync/test_service_config.py @@ -1,495 +1,495 @@ -import pytest -import time - -import os -import json - -from aerospike_vector_search import AVSServerError, types -from aerospike_vector_search import AdminClient - - -class service_config_parse_test_case: - def __init__(self, *, service_config_path): - self.service_config_path = service_config_path - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_parse_test_case( - service_config_path="service_configs/master.json" - ), - ], -) -def test_admin_client_service_config_parse( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - service_config_path=test_case.service_config_path, - ssl_target_name_override=ssl_target_name_override, - ) as client: - pass - - -class service_config_test_case: - def __init__( - self, *, service_config_path, namespace, name, vector_field, dimensions - ): - - script_dir = os.path.dirname(os.path.abspath(__file__)) - - self.service_config_path = os.path.abspath( 
- os.path.join(script_dir, "..", "..", service_config_path) - ) - - with open(self.service_config_path, "rb") as f: - self.service_config = json.load(f) - - self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ - "maxAttempts" - ] - self.initial_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] - ) - self.max_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] - ) - self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ - "backoffMultiplier" - ] - self.retryable_status_codes = self.service_config["methodConfig"][0][ - "retryPolicy" - ]["retryableStatusCodes"] - self.namespace = namespace - self.name = name - self.vector_field = vector_field - self.dimensions = dimensions - - -def calculate_expected_time( - max_attempts, - initial_backoff, - backoff_multiplier, - max_backoff, - retryable_status_codes, -): - - current_backkoff = initial_backoff - - expected_time = 0 - for attempt in range(max_attempts - 1): - expected_time += current_backkoff - current_backkoff *= backoff_multiplier - current_backkoff = min(current_backkoff, max_backoff) - - return expected_time - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retries.json", - namespace="test", - name="service_config_index_1", - vector_field="example_1", - dimensions=1024, - ) - ], -) -def test_admin_client_service_config_retries( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - 
username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - service_config_path=test_case.service_config_path, - ssl_target_name_override=ssl_target_name_override, - - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/initial_backoff.json", - namespace="test", - name="service_config_index_2", - vector_field="example_1", - dimensions=1024, - ) - ], -) -def test_admin_client_service_config_initial_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - 
service_config_path=test_case.service_config_path, - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/max_backoff.json", - namespace="test", - name="service_config_index_3", - vector_field="example_1", - dimensions=1024, - ), - service_config_test_case( - service_config_path="service_configs/max_backoff_lower_than_initial.json", - namespace="test", - name="service_config_index_4", - vector_field="example_1", - dimensions=1024, - ), - ], -) -def test_admin_client_service_config_max_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - 
service_config_path=test_case.service_config_path, - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/backoff_multiplier.json", - namespace="test", - name="service_config_index_5", - vector_field="example_1", - dimensions=1024, - ) - ], -) -def test_admin_client_service_config_backoff_multiplier( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - 
dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retryable_status_codes.json", - namespace="test", - name="service_config_index_6", - vector_field=None, - dimensions=None, - ) - ], -) -def test_admin_client_service_config_retryable_status_codes( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_get_status( - namespace=test_case.namespace, - 
name=test_case.name, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 +# import pytest +# import time + +# import os +# import json + +# from aerospike_vector_search import AVSServerError, types +# from aerospike_vector_search import AdminClient + + +# class service_config_parse_test_case: +# def __init__(self, *, service_config_path): +# self.service_config_path = service_config_path + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_parse_test_case( +# service_config_path="service_configs/master.json" +# ), +# ], +# ) +# def test_admin_client_service_config_parse( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# service_config_path=test_case.service_config_path, +# ssl_target_name_override=ssl_target_name_override, +# ) as client: +# pass + + +# class service_config_test_case: +# def __init__( +# self, *, service_config_path, namespace, name, vector_field, dimensions +# ): + +# script_dir = os.path.dirname(os.path.abspath(__file__)) + +# self.service_config_path = os.path.abspath( +# os.path.join(script_dir, "..", "..", service_config_path) +# ) + +# with open(self.service_config_path, "rb") as f: +# self.service_config = json.load(f) + +# self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ +# "maxAttempts" +# ] +# self.initial_backoff = int( +# 
self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] +# ) +# self.max_backoff = int( +# self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] +# ) +# self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ +# "backoffMultiplier" +# ] +# self.retryable_status_codes = self.service_config["methodConfig"][0][ +# "retryPolicy" +# ]["retryableStatusCodes"] +# self.namespace = namespace +# self.name = name +# self.vector_field = vector_field +# self.dimensions = dimensions + + +# def calculate_expected_time( +# max_attempts, +# initial_backoff, +# backoff_multiplier, +# max_backoff, +# retryable_status_codes, +# ): + +# current_backkoff = initial_backoff + +# expected_time = 0 +# for attempt in range(max_attempts - 1): +# expected_time += current_backkoff +# current_backkoff *= backoff_multiplier +# current_backkoff = min(current_backkoff, max_backoff) + +# return expected_time + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retries.json", +# namespace="test", +# name="service_config_index_1", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# def test_admin_client_service_config_retries( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# service_config_path=test_case.service_config_path, +# 
ssl_target_name_override=ssl_target_name_override, + +# ) as client: + +# try: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/initial_backoff.json", +# namespace="test", +# name="service_config_index_2", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# def test_admin_client_service_config_initial_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# client.index_create( +# namespace=test_case.namespace, 
+# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/max_backoff.json", +# namespace="test", +# name="service_config_index_3", +# vector_field="example_1", +# dimensions=1024, +# ), +# service_config_test_case( +# service_config_path="service_configs/max_backoff_lower_than_initial.json", +# namespace="test", +# name="service_config_index_4", +# vector_field="example_1", +# dimensions=1024, +# ), +# ], +# ) +# def test_admin_client_service_config_max_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as 
client: + +# try: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/backoff_multiplier.json", +# namespace="test", +# name="service_config_index_5", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# def test_admin_client_service_config_backoff_multiplier( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# 
vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retryable_status_codes.json", +# namespace="test", +# name="service_config_index_6", +# vector_field=None, +# dimensions=None, +# ) +# ], +# ) +# def test_admin_client_service_config_retryable_status_codes( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = 
time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_get_status( +# namespace=test_case.namespace, +# name=test_case.name, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 diff --git a/tests/standard/sync/test_vector_client_get.py b/tests/standard/sync/test_vector_client_get.py index 3dd85e67..b7fc1c48 100644 --- a/tests/standard/sync/test_vector_client_get.py +++ b/tests/standard/sync/test_vector_client_get.py @@ -110,10 +110,10 @@ def __init__( ], indirect=["record"], ) -def test_vector_get(session_vector_client, test_case, random_key, record_data): +def test_vector_get(session_vector_client, test_case, record): result = session_vector_client.get( namespace=test_case.namespace, - key=random_key, + key=record, include_fields=test_case.include_fields, exclude_fields=test_case.exclude_fields, ) @@ -121,7 +121,7 @@ def test_vector_get(session_vector_client, test_case, random_key, record_data): if test_case.set_name is None: test_case.set_name = "" assert result.key.set == test_case.set_name - assert result.key.key == random_key + assert result.key.key == record assert result.fields == test_case.expected_fields diff --git a/tests/standard/sync/test_vector_client_is_indexed.py b/tests/standard/sync/test_vector_client_is_indexed.py index d1a61c2e..0c4baa97 100644 --- a/tests/standard/sync/test_vector_client_is_indexed.py +++ b/tests/standard/sync/test_vector_client_is_indexed.py @@ -1,7 +1,7 @@ import pytest from aerospike_vector_search import AVSServerError -from utils import DEFAULT_NAMESPACE +from utils import random_name, DEFAULT_NAMESPACE from .sync_utils import wait_for_index import grpc @@ -22,7 +22,7 @@ def test_vector_is_indexed( result = session_vector_client.is_indexed( namespace=DEFAULT_NAMESPACE, - key="0", + key=record, index_name=index, ) assert result is True diff --git a/tests/standard/sync/test_vector_client_search_by_key.py 
b/tests/standard/sync/test_vector_client_search_by_key.py index f46c401e..044a0ba0 100644 --- a/tests/standard/sync/test_vector_client_search_by_key.py +++ b/tests/standard/sync/test_vector_client_search_by_key.py @@ -301,6 +301,19 @@ def test_vector_search_by_key( name=test_case.index_name, vector_field=test_case.vector_field, dimensions=test_case.index_dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) for key, rec in test_case.record_data.items(): diff --git a/tests/standard/sync/test_vector_search.py b/tests/standard/sync/test_vector_search.py index 5648c6fb..415e038c 100644 --- a/tests/standard/sync/test_vector_search.py +++ b/tests/standard/sync/test_vector_search.py @@ -39,7 +39,7 @@ def __init__( vector_search_test_case( index_name="basic_search", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], namespace="test", @@ -49,7 +49,7 @@ def __init__( record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ @@ -61,7 +61,7 @@ def __init__( ), fields={ "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, distance=3.0, ), @@ -70,7 +70,7 @@ def __init__( vector_search_test_case( index_name="field_filter", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], namespace="test", @@ -80,7 +80,7 @@ def __init__( record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ From 0b95a7c58fb557978e7e7a370473c36d189861a2 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 27 Dec 2024 09:55:31 -0800 Subject: [PATCH 
07/21] add record fixture --- tests/standard/aio/test_vector_client_search_by_key.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/standard/aio/test_vector_client_search_by_key.py b/tests/standard/aio/test_vector_client_search_by_key.py index 9434e1d1..e71b9bc0 100644 --- a/tests/standard/aio/test_vector_client_search_by_key.py +++ b/tests/standard/aio/test_vector_client_search_by_key.py @@ -333,6 +333,7 @@ async def test_vector_search_by_key( session_vector_client, session_admin_client, setup_index, + setup_records, test_case, ): await wait_for_index(session_admin_client, DEFAULT_NAMESPACE, INDEX) From 74f5a90aa318f023509f71f208a6d52c43d83b64 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 27 Dec 2024 11:19:32 -0800 Subject: [PATCH 08/21] change vector bin in aio search by key tests --- .../aio/test_vector_client_search_by_key.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/standard/aio/test_vector_client_search_by_key.py b/tests/standard/aio/test_vector_client_search_by_key.py index e71b9bc0..7bd2698f 100644 --- a/tests/standard/aio/test_vector_client_search_by_key.py +++ b/tests/standard/aio/test_vector_client_search_by_key.py @@ -8,7 +8,7 @@ INDEX = "sbk_index" NAMESPACE = DEFAULT_NAMESPACE DIMENSIONS = 3 -VEC_BIN = "vector" +VEC_BIN = "asbkvec" SET_NAME = "test_set" @@ -71,7 +71,7 @@ async def setup_index( ) -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(scope="module") async def setup_records( session_vector_client, ): @@ -127,8 +127,14 @@ async def setup_records( namespace=DEFAULT_NAMESPACE, key=key, ) - + for key in keys: + await session_vector_client.delete( + namespace=DEFAULT_NAMESPACE, + key=key, + set_name=SET_NAME, + ) + #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( @@ -137,7 +143,7 @@ async def setup_records( # test string key vector_search_by_key_test_case( index_dimensions=3, - vector_field="vector", + vector_field=VEC_BIN, limit=2, key="rec1", 
key_namespace=DEFAULT_NAMESPACE, @@ -154,7 +160,7 @@ async def setup_records( ), fields={ "bin": 1, - "vector": [1.0, 1.0, 1.0], + VEC_BIN: [1.0, 1.0, 1.0], }, distance=0.0, ), @@ -166,7 +172,7 @@ async def setup_records( ), fields={ "bin": 2, - "vector": [2.0, 2.0, 2.0], + VEC_BIN: [2.0, 2.0, 2.0], }, distance=3.0, ), @@ -175,7 +181,7 @@ async def setup_records( # test int key vector_search_by_key_test_case( index_dimensions=3, - vector_field="vector", + vector_field=VEC_BIN, limit=3, key=2, key_namespace=DEFAULT_NAMESPACE, @@ -216,7 +222,7 @@ async def setup_records( # test bytes key vector_search_by_key_test_case( index_dimensions=3, - vector_field="vector", + vector_field=VEC_BIN, limit=3, key=bytes("rec5", "utf-8"), key_namespace=DEFAULT_NAMESPACE, @@ -260,7 +266,7 @@ async def setup_records( # # vector_search_by_key_test_case( # # index_name="field_filter", # # index_dimensions=3, - # # vector_field="vector", + # # vector_field=VEC_BIN, # # limit=3, # # key=bytearray("rec1", "utf-8"), # # namespace=DEFAULT_NAMESPACE, @@ -270,11 +276,11 @@ async def setup_records( # # record_data={ # # bytearray("rec1", "utf-8"): { # # "bin": 1, - # # "vector": [1.0, 1.0, 1.0], + # # VEC_BIN: [1.0, 1.0, 1.0], # # }, # # bytearray("rec1", "utf-8"): { # # "bin": 2, - # # "vector": [2.0, 2.0, 2.0], + # # VEC_BIN: [2.0, 2.0, 2.0], # # }, # # }, # # expected_results=[ @@ -292,7 +298,7 @@ async def setup_records( # test with set name vector_search_by_key_test_case( index_dimensions=3, - vector_field="vector", + vector_field=VEC_BIN, limit=2, key="srec100", key_namespace=DEFAULT_NAMESPACE, @@ -309,7 +315,7 @@ async def setup_records( ), fields={ "bin": 100, - "vector": [100.0] * DIMENSIONS, + VEC_BIN: [100.0] * DIMENSIONS, }, distance=0.0, ), @@ -321,7 +327,7 @@ async def setup_records( ), fields={ "bin": 101, - "vector": [101.0] * DIMENSIONS, + VEC_BIN: [101.0] * DIMENSIONS, }, distance=3.0, ), From 1e3fbf9f56b8cecb1e87ebdc1e2fea2f4e0e9302 Mon Sep 17 00:00:00 2001 From: dylan Date: 
Fri, 27 Dec 2024 15:11:45 -0800 Subject: [PATCH 09/21] ci: paramaterize async/sync clients and remove duplicate async test suite --- .github/workflows/integration_test.yml | 5 +- src/aerospike_vector_search/admin.py | 2 +- tests/standard/aio/__init__.py | 0 tests/standard/aio/aio_utils.py | 35 - tests/standard/aio/conftest.py | 276 -------- tests/standard/aio/requirements.txt | 4 - .../aio/test_admin_client_index_create.py | 615 ------------------ .../aio/test_admin_client_index_drop.py | 39 -- .../aio/test_admin_client_index_get.py | 109 ---- .../aio/test_admin_client_index_get_status.py | 38 -- .../aio/test_admin_client_index_list.py | 42 -- .../aio/test_admin_client_index_update.py | 110 ---- .../aio/test_extensive_vector_search.py | 452 ------------- tests/standard/aio/test_service_config.py | 496 -------------- .../standard/aio/test_vector_client_delete.py | 97 --- .../standard/aio/test_vector_client_exists.py | 68 -- tests/standard/aio/test_vector_client_get.py | 159 ----- ...ector_client_index_get_percent_unmerged.py | 66 -- .../standard/aio/test_vector_client_insert.py | 127 ---- .../aio/test_vector_client_is_indexed.py | 49 -- .../aio/test_vector_client_search_by_key.py | 446 ------------- .../standard/aio/test_vector_client_update.py | 117 ---- .../standard/aio/test_vector_client_upsert.py | 153 ----- tests/standard/aio/test_vector_search.py | 172 ----- tests/standard/conftest.py | 362 +++++++++++ tests/standard/{sync => }/requirements.txt | 0 tests/standard/sync/__init__.py | 0 tests/standard/sync/conftest.py | 271 -------- tests/standard/sync/sync_utils.py | 34 - .../test_admin_client_index_create.py | 3 +- .../test_admin_client_index_drop.py | 0 .../{sync => }/test_admin_client_index_get.py | 2 +- .../test_admin_client_index_get_status.py | 4 +- .../test_admin_client_index_list.py | 2 +- .../test_admin_client_index_update.py | 0 .../test_extensive_vector_search.py | 0 .../{sync => }/test_service_config.py | 5 + .../{sync => 
}/test_vector_client_delete.py | 0 .../{sync => }/test_vector_client_exists.py | 0 .../{sync => }/test_vector_client_get.py | 0 ...ector_client_index_get_percent_unmerged.py | 0 .../{sync => }/test_vector_client_insert.py | 0 .../test_vector_client_is_indexed.py | 2 +- .../test_vector_client_search_by_key.py | 13 +- .../{sync => }/test_vector_client_update.py | 2 +- .../{sync => }/test_vector_client_upsert.py | 2 +- .../standard/{sync => }/test_vector_search.py | 9 +- tests/utils.py | 39 +- 48 files changed, 429 insertions(+), 3998 deletions(-) delete mode 100644 tests/standard/aio/__init__.py delete mode 100644 tests/standard/aio/aio_utils.py delete mode 100644 tests/standard/aio/conftest.py delete mode 100644 tests/standard/aio/requirements.txt delete mode 100644 tests/standard/aio/test_admin_client_index_create.py delete mode 100644 tests/standard/aio/test_admin_client_index_drop.py delete mode 100644 tests/standard/aio/test_admin_client_index_get.py delete mode 100644 tests/standard/aio/test_admin_client_index_get_status.py delete mode 100644 tests/standard/aio/test_admin_client_index_list.py delete mode 100644 tests/standard/aio/test_admin_client_index_update.py delete mode 100644 tests/standard/aio/test_extensive_vector_search.py delete mode 100644 tests/standard/aio/test_service_config.py delete mode 100644 tests/standard/aio/test_vector_client_delete.py delete mode 100644 tests/standard/aio/test_vector_client_exists.py delete mode 100644 tests/standard/aio/test_vector_client_get.py delete mode 100644 tests/standard/aio/test_vector_client_index_get_percent_unmerged.py delete mode 100644 tests/standard/aio/test_vector_client_insert.py delete mode 100644 tests/standard/aio/test_vector_client_is_indexed.py delete mode 100644 tests/standard/aio/test_vector_client_search_by_key.py delete mode 100644 tests/standard/aio/test_vector_client_update.py delete mode 100644 tests/standard/aio/test_vector_client_upsert.py delete mode 100644 
tests/standard/aio/test_vector_search.py rename tests/standard/{sync => }/requirements.txt (100%) delete mode 100644 tests/standard/sync/__init__.py delete mode 100644 tests/standard/sync/conftest.py delete mode 100644 tests/standard/sync/sync_utils.py rename tests/standard/{sync => }/test_admin_client_index_create.py (99%) rename tests/standard/{sync => }/test_admin_client_index_drop.py (100%) rename tests/standard/{sync => }/test_admin_client_index_get.py (98%) rename tests/standard/{sync => }/test_admin_client_index_get_status.py (94%) rename tests/standard/{sync => }/test_admin_client_index_list.py (97%) rename tests/standard/{sync => }/test_admin_client_index_update.py (100%) rename tests/standard/{sync => }/test_extensive_vector_search.py (100%) rename tests/standard/{sync => }/test_service_config.py (98%) rename tests/standard/{sync => }/test_vector_client_delete.py (100%) rename tests/standard/{sync => }/test_vector_client_exists.py (100%) rename tests/standard/{sync => }/test_vector_client_get.py (100%) rename tests/standard/{sync => }/test_vector_client_index_get_percent_unmerged.py (100%) rename tests/standard/{sync => }/test_vector_client_insert.py (100%) rename tests/standard/{sync => }/test_vector_client_is_indexed.py (97%) rename tests/standard/{sync => }/test_vector_client_search_by_key.py (98%) rename tests/standard/{sync => }/test_vector_client_update.py (99%) rename tests/standard/{sync => }/test_vector_client_upsert.py (99%) rename tests/standard/{sync => }/test_vector_search.py (96%) diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index d2e79900..a8ea56f0 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -17,6 +17,7 @@ jobs: strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12"] + async: ["--sync", "--async"] steps: @@ -60,7 +61,7 @@ jobs: working-directory: tests - - name: Run unit tests + - name: Run integration tests run: | docker run -d 
--network=host -p 5000:5000 --name aerospike-vector-search -v ./aerospike-vector-search.yml:/etc/aerospike-vector-search/aerospike-vector-search.yml -v ./features.conf:/etc/aerospike-vector-search/features.conf aerospike/aerospike-vector-search:1.0.0 @@ -70,7 +71,7 @@ jobs: sleep 5 docker ps - python -m pytest standard -s --host 0.0.0.0 --port 5000 --cov=aerospike_vector_search + python -m pytest standard -s --host 0.0.0.0 --port 5000 --cov=aerospike_vector_search ${{ matrix.async }} mv .coverage coverage_data working-directory: tests diff --git a/src/aerospike_vector_search/admin.py b/src/aerospike_vector_search/admin.py index c84637f1..a9638d09 100644 --- a/src/aerospike_vector_search/admin.py +++ b/src/aerospike_vector_search/admin.py @@ -371,7 +371,7 @@ def index_get_status( Note: This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, - the records may not immediately begin to merge into the index. To wait for all records to be merged into an index, use vector_client.wait_for_index_completion. + the records may not immediately begin to merge into the index. Warning: This API is subject to change. 
""" diff --git a/tests/standard/aio/__init__.py b/tests/standard/aio/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/standard/aio/aio_utils.py b/tests/standard/aio/aio_utils.py deleted file mode 100644 index 2f39646b..00000000 --- a/tests/standard/aio/aio_utils.py +++ /dev/null @@ -1,35 +0,0 @@ -import asyncio - - -async def drop_specified_index(admin_client, namespace, name): - await admin_client.index_drop(namespace=namespace, name=name) - - -def gen_records(count: int, vec_bin: str, vec_dim: int): - num = 0 - while num < count: - key_and_rec = ( - num, - { "id": num, vec_bin: [float(num)] * vec_dim} - ) - yield key_and_rec - num += 1 - - -async def wait_for_index(admin_client, namespace: str, index: str): - - verticies = 0 - unmerged_recs = 0 - - while verticies == 0 or unmerged_recs > 0: - status = await admin_client.index_get_status( - namespace=namespace, - name=index, - ) - - verticies = status.index_healer_vertices_valid - unmerged_recs = status.unmerged_record_count - - # print(verticies) - # print(unmerged_recs) - await asyncio.sleep(0.5) \ No newline at end of file diff --git a/tests/standard/aio/conftest.py b/tests/standard/aio/conftest.py deleted file mode 100644 index 354bba04..00000000 --- a/tests/standard/aio/conftest.py +++ /dev/null @@ -1,276 +0,0 @@ -import asyncio -import pytest -import random -import string -import grpc - -from aerospike_vector_search.aio import Client -from aerospike_vector_search.aio.admin import Client as AdminClient -from aerospike_vector_search import types, AVSServerError - -from .aio_utils import gen_records -import utils - -#import logging -#logger = logging.getLogger(__name__) -#logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.DEBUG) - - -# default test values -DEFAULT_NAMESPACE = utils.DEFAULT_NAMESPACE -DEFAULT_INDEX_DIMENSION = utils.DEFAULT_INDEX_DIMENSION -DEFAULT_VECTOR_FIELD = utils.DEFAULT_VECTOR_FIELD -DEFAULT_INDEX_ARGS = { - "namespace": 
DEFAULT_NAMESPACE, - "vector_field": DEFAULT_VECTOR_FIELD, - "dimensions": DEFAULT_INDEX_DIMENSION, -} - -DEFAULT_RECORD_GENERATOR = gen_records -DEFAULT_NUM_RECORDS = 1000 -DEFAULT_RECORDS_ARGS = { - "record_generator": DEFAULT_RECORD_GENERATOR, - "namespace": DEFAULT_NAMESPACE, - "vector_field": DEFAULT_VECTOR_FIELD, - "dimensions": DEFAULT_INDEX_DIMENSION, - "num_records": DEFAULT_NUM_RECORDS, -} - -@pytest.fixture(scope="module", autouse=True) -async def drop_all_indexes( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - is_loadbalancer, -): - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - ) as client: - - index_list = await client.index_list() - tasks = [] - for item in index_list: - tasks.append(asyncio.create_task(client.index_drop(namespace="test", name=item["id"]["name"]))) - - await asyncio.gather(*tasks) - - - -@pytest.fixture(scope="module") -async def session_admin_client( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - is_loadbalancer, -): - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - client = AdminClient( - seeds=types.HostPort(host=host, 
port=port), - is_loadbalancer=is_loadbalancer, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - username=username, - password=password, - ) - - yield client - await client.close() - - -@pytest.fixture(scope="module") -async def session_vector_client( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - is_loadbalancer, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - client = Client( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - username=username, - password=password, - ) - yield client - await client.close() - - -@pytest.fixture -async def function_admin_client( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - is_loadbalancer, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - client = AdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - username=username, - password=password, - ) - yield client - await client.close() - - -@pytest.fixture() -def index_name(): - 
length = random.randint(1, 15) - return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) - - -@pytest.fixture(params=[DEFAULT_INDEX_ARGS]) -async def index(session_admin_client, index_name, request): - args = request.param - namespace = args.get("namespace", DEFAULT_NAMESPACE) - vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) - dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - await session_admin_client.index_create( - name = index_name, - namespace = namespace, - vector_field = vector_field, - dimensions = dimensions, - index_params=types.HnswParams( - batching_params=types.HnswBatchingParams( - # 10_000 is the minimum value, in order for the tests to run as - # fast as possible we set it to the minimum value so records are indexed - # quickly - index_interval=10_000, - ), - healer_params=types.HnswHealerParams( - # run the healer every second - # for fast indexing - schedule="* * * * * ?" - ) - ) - ) - yield index_name - try: - await session_admin_client.index_drop(namespace=namespace, name=index_name) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: - pass - else: - raise - - -@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) -async def records(session_vector_client, request): - args = request.param - record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) - namespace = args.get("namespace", DEFAULT_NAMESPACE) - num_records = args.get("num_records", DEFAULT_NUM_RECORDS) - vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) - dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - set_name = args.get("set_name", None) - keys = [] - for key, rec in record_generator(count=num_records, vec_bin=vector_field, vec_dim=dimensions): - await session_vector_client.upsert( - namespace=namespace, - key=key, - record_data=rec, - set_name=set_name, - ) - keys.append(key) - yield keys - for key in keys: - await session_vector_client.delete(key=key, 
namespace=namespace) - - -@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) -async def record(session_vector_client, request): - args = request.param - record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) - namespace = args.get("namespace", DEFAULT_NAMESPACE) - vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) - dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - set_name = args.get("set_name", None) - key, rec = next(record_generator(count=1, vec_bin=vector_field, vec_dim=dimensions)) - await session_vector_client.upsert( - namespace=namespace, - key=key, - record_data=rec, - set_name=set_name, - ) - yield key - await session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file diff --git a/tests/standard/aio/requirements.txt b/tests/standard/aio/requirements.txt deleted file mode 100644 index 483b5812..00000000 --- a/tests/standard/aio/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -numpy==1.26.4 -pytest==7.4.0 -pytest-aio==1.5.0 -.. 
\ No newline at end of file diff --git a/tests/standard/aio/test_admin_client_index_create.py b/tests/standard/aio/test_admin_client_index_create.py deleted file mode 100644 index 0897afca..00000000 --- a/tests/standard/aio/test_admin_client_index_create.py +++ /dev/null @@ -1,615 +0,0 @@ -import pytest -from aerospike_vector_search import types, AVSServerError -import grpc - -from ...utils import random_name, DEFAULT_NAMESPACE - -from .aio_utils import drop_specified_index -from hypothesis import given, settings, Verbosity, Phase - -server_defaults = { - "m": 16, - "ef_construction": 100, - "ef": 100, - "batching_params": { - "max_index_records": 10000, - "index_interval": 10000, - } -} - -class index_create_test_case: - def __init__( - self, - *, - namespace, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_labels, - index_storage, - timeout - ): - self.namespace = namespace - self.vector_field = vector_field - self.dimensions = dimensions - if vector_distance_metric == None: - self.vector_distance_metric = types.VectorDistanceMetric.SQUARED_EUCLIDEAN - else: - self.vector_distance_metric = vector_distance_metric - self.sets = sets - self.index_params = index_params - self.index_labels = index_labels - self.index_storage = index_storage - self.timeout = timeout - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_1", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - ], -) -async def test_index_create(session_admin_client, test_case, random_name): - if test_case == None: - return - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - 
dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - results = await session_admin_client.index_list() - found = False - for result in results: - if result["id"]["name"] == random_name: - found = True - assert result["id"]["namespace"] == test_case.namespace - assert result["dimensions"] == test_case.dimensions - assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == 16 - assert result["hnsw_params"]["ef_construction"] == 100 - assert result["hnsw_params"]["ef"] == 100 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 - assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) - assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == test_case.namespace - assert result["storage"]["set_name"] == random_name - assert found == True - await drop_specified_index(session_admin_client, test_case.namespace, random_name) - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_2", - dimensions=495, - vector_distance_metric=None, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_3", - dimensions=2048, - vector_distance_metric=None, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - ], -) -async def test_index_create_with_dimnesions( - session_admin_client, test_case, 
random_name -): - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - - results = await session_admin_client.index_list() - - found = False - for result in results: - - if result["id"]["name"] == random_name: - found = True - assert result["id"]["namespace"] == test_case.namespace - assert result["dimensions"] == test_case.dimensions - assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == 16 - assert result["hnsw_params"]["ef_construction"] == 100 - assert result["hnsw_params"]["ef"] == 100 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 - assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) - assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == test_case.namespace - assert result["storage"]["set_name"] == random_name - assert found == True - - await drop_specified_index(session_admin_client, test_case.namespace, random_name) - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_4", - dimensions=1024, - vector_distance_metric=types.VectorDistanceMetric.COSINE, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_5", - dimensions=1024, - 
vector_distance_metric=types.VectorDistanceMetric.DOT_PRODUCT, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_6", - dimensions=1024, - vector_distance_metric=types.VectorDistanceMetric.MANHATTAN, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_7", - dimensions=1024, - vector_distance_metric=types.VectorDistanceMetric.HAMMING, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - ], -) -async def test_index_create_with_vector_distance_metric( - session_admin_client, test_case, random_name -): - - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - results = await session_admin_client.index_list() - found = False - for result in results: - if result["id"]["name"] == random_name: - found = True - assert result["id"]["namespace"] == test_case.namespace - assert result["dimensions"] == test_case.dimensions - assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == 16 - assert result["hnsw_params"]["ef_construction"] == 100 - assert result["hnsw_params"]["ef"] == 100 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 - assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) - assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 
30000 - assert result["storage"]["namespace"] == test_case.namespace - assert result["storage"]["set_name"] == random_name - assert found == True - await drop_specified_index(session_admin_client, test_case.namespace, random_name) - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_8", - dimensions=1024, - vector_distance_metric=None, - sets="Demo", - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_9", - dimensions=1024, - vector_distance_metric=None, - sets="Cheese", - index_params=None, - index_labels=None, - index_storage=None, - timeout=None, - ), - ], -) -async def test_index_create_with_sets(session_admin_client, test_case, random_name): - - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - results = await session_admin_client.index_list() - found = False - for result in results: - if result["id"]["name"] == random_name: - found = True - assert result["id"]["namespace"] == test_case.namespace - assert result["dimensions"] == test_case.dimensions - assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == 16 - assert result["hnsw_params"]["ef_construction"] == 100 - assert result["hnsw_params"]["ef"] == 100 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 - assert 
result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) - assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == test_case.namespace - assert result["storage"]["set_name"] == random_name - assert found == True - await drop_specified_index(session_admin_client, test_case.namespace, random_name) - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_10", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=types.HnswParams( - m=32, - ef_construction=200, - ef=400, - enable_vector_integrity_check= False, - ), - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_11", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=types.HnswParams( - m=8, - ef_construction=50, - ef=25, - enable_vector_integrity_check= True - ), - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_12", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=types.HnswParams( - m=8, - enable_vector_integrity_check= True, - ), - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_13", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=types.HnswParams( - batching_params=types.HnswBatchingParams(max_index_records=2000, index_interval=20000, max_reindex_records=1500, reindex_interval=70000), - enable_vector_integrity_check= True - ), - index_labels=None, - index_storage=None, - timeout=None, - ), - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - 
vector_field="example_20", - dimensions=1024, - vector_distance_metric=None, - sets="demo", - index_params=types.HnswParams( - index_caching_params=types.HnswCachingParams(max_entries=10, expiry=3000), - healer_params=types.HnswHealerParams( - max_scan_rate_per_node=80, - max_scan_page_size=40, - re_index_percent=50, - schedule="* 0/5 * ? * * *", - parallelism=4, - ), - merge_params=types.HnswIndexMergeParams( - index_parallelism=10, - reindex_parallelism=3 - ), - enable_vector_integrity_check= True - ), - index_labels=None, - index_storage=None, - timeout=None, - ), - ], -) -async def test_index_create_with_index_params( - session_admin_client, test_case, random_name -): - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - results = await session_admin_client.index_list() - found = False - for result in results: - if result["id"]["name"] == random_name: - found = True - assert result["id"]["namespace"] == test_case.namespace - assert result["dimensions"] == test_case.dimensions - assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == test_case.index_params.m or server_defaults - assert ( - result["hnsw_params"]["ef_construction"] - == test_case.index_params.ef_construction or server_defaults - ) - assert result["hnsw_params"]["ef"] == test_case.index_params.ef or server_defaults - assert result["hnsw_params"][ - "enable_vector_integrity_check"] == test_case.index_params.enable_vector_integrity_check - assert ( - result["hnsw_params"]["batching_params"]["max_index_records"] - == test_case.index_params.batching_params.max_index_records or server_defaults - ) - assert ( - 
result["hnsw_params"]["batching_params"]["index_interval"] - == test_case.index_params.batching_params.index_interval or server_defaults - ) - assert ( - result["hnsw_params"]["batching_params"]["max_reindex_records"] - == test_case.index_params.batching_params.max_reindex_records or server_defaults - ) - - assert ( - result["hnsw_params"]["batching_params"]["reindex_interval"] - == test_case.index_params.batching_params.reindex_interval or server_defaults - ) - - assert result["storage"]["namespace"] == test_case.namespace - assert result["storage"]["set_name"] == random_name - assert found == True - await drop_specified_index(session_admin_client, test_case.namespace, random_name) - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_14", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=None, - index_labels={"size": "large", "price": "$4.99", "currencyType": "CAN"}, - index_storage=None, - timeout=None, - ), - ], -) -async def test_index_create_index_labels(session_admin_client, test_case, random_name): - if test_case == None: - return - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - results = await session_admin_client.index_list() - found = False - for result in results: - if result["id"]["name"] == random_name: - found = True - assert result["id"]["namespace"] == test_case.namespace - assert result["dimensions"] == test_case.dimensions - assert result["field"] == test_case.vector_field - assert 
result["hnsw_params"]["m"] == 16 - assert result["hnsw_params"]["ef_construction"] == 100 - assert result["hnsw_params"]["ef"] == 100 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 - assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) - assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == test_case.namespace - assert result["storage"]["set_name"] == random_name - assert found == True - await drop_specified_index(session_admin_client, test_case.namespace, random_name) - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_15", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=None, - index_labels=None, - index_storage=types.IndexStorage(namespace=DEFAULT_NAMESPACE, set_name="foo"), - timeout=None, - ), - ], -) -async def test_index_create_index_storage(session_admin_client, test_case, random_name): - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - results = await session_admin_client.index_list() - found = False - for result in results: - if result["id"]["name"] == random_name: - found = True - assert result["id"]["namespace"] == test_case.namespace - assert result["dimensions"] == test_case.dimensions - assert result["field"] == test_case.vector_field - assert result["hnsw_params"]["m"] == 
16 - assert result["hnsw_params"]["ef_construction"] == 100 - assert result["hnsw_params"]["ef"] == 100 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 - assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) - assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == test_case.index_storage.namespace - assert result["storage"]["set_name"] == test_case.index_storage.set_name - assert found == True - - -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000, phases=(Phase.generate,)) -@pytest.mark.parametrize( - "test_case", - [ - index_create_test_case( - namespace=DEFAULT_NAMESPACE, - vector_field="example_16", - dimensions=1024, - vector_distance_metric=None, - sets=None, - index_params=None, - index_labels=None, - index_storage=None, - timeout=0.0001, - ), - ], -) -async def test_index_create_timeout( - session_admin_client, test_case, random_name, with_latency -): - - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - - await session_admin_client.index_create( - namespace=test_case.namespace, - name=random_name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - vector_distance_metric=test_case.vector_distance_metric, - sets=test_case.sets, - index_params=test_case.index_params, - index_labels=test_case.index_labels, - index_storage=test_case.index_storage, - timeout=test_case.timeout, - ) - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_admin_client_index_drop.py b/tests/standard/aio/test_admin_client_index_drop.py deleted file mode 100644 index 9128a5c8..00000000 --- a/tests/standard/aio/test_admin_client_index_drop.py +++ /dev/null @@ 
-1,39 +0,0 @@ -import pytest -from aerospike_vector_search import AVSServerError -import grpc - -from utils import DEFAULT_NAMESPACE - - -from hypothesis import given, settings, Verbosity - - -@pytest.mark.parametrize("empty_test_case", [None, None]) -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=2000) -async def test_index_drop(session_admin_client, empty_test_case, index): - await session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) - - result = session_admin_client.index_list() - result = await result - for index in result: - assert index["id"]["name"] != index - - -@pytest.mark.parametrize("empty_test_case", [None, None]) -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000) -async def test_index_drop_timeout( - session_admin_client, empty_test_case, index, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - - await session_admin_client.index_drop( - namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 - ) - - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_admin_client_index_get.py b/tests/standard/aio/test_admin_client_index_get.py deleted file mode 100644 index ff21b3e9..00000000 --- a/tests/standard/aio/test_admin_client_index_get.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest -from ...utils import DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD - -from hypothesis import given, settings, Verbosity - -from aerospike_vector_search import AVSServerError -import grpc - - -@pytest.mark.parametrize("empty_test_case", [None]) -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000) -async def test_index_get(session_admin_client, empty_test_case, index): - result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) - 
- assert result["id"]["name"] == index - assert result["id"]["namespace"] == DEFAULT_NAMESPACE - assert result["dimensions"] == DEFAULT_INDEX_DIMENSION - assert result["field"] == DEFAULT_VECTOR_FIELD - assert result["hnsw_params"]["m"] == 16 - assert result["hnsw_params"]["ef_construction"] == 100 - assert result["hnsw_params"]["ef"] == 100 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 - assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) - assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == DEFAULT_NAMESPACE - assert result["storage"].set_name == index - assert result["storage"]["set_name"] == index - - # Defaults - assert result["sets"] == "" - assert result["vector_distance_metric"] == 0 - - assert result["hnsw_params"]["max_mem_queue_size"] == 1000000 - assert result["hnsw_params"]["index_caching_params"]["max_entries"] == 2000000 - assert result["hnsw_params"]["index_caching_params"]["expiry"] == 3600000 - - assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 1000 - assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 10000 - assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 10.0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" 
- assert result["hnsw_params"]["healer_params"]["parallelism"] == 1 - - # index parallelism and reindex parallelism are dynamic depending on the CPU cores of the host - # assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 80 - # assert result["hnsw_params"]["merge_params"]["reindex_parallelism"] == 26 - - -@pytest.mark.parametrize("empty_test_case", [None]) -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000) -async def test_index_get_no_defaults(session_admin_client, empty_test_case, index): - result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) - - assert result["id"]["name"] == index - assert result["id"]["namespace"] == DEFAULT_NAMESPACE - assert result["dimensions"] == DEFAULT_INDEX_DIMENSION - assert result["field"] == DEFAULT_VECTOR_FIELD - - # Defaults - assert result["sets"] == "" - assert result["vector_distance_metric"] == 0 - - assert result["hnsw_params"]["m"] == 0 - assert result["hnsw_params"]["ef"] == 0 - assert result["hnsw_params"]["ef_construction"] == 0 - assert result["hnsw_params"]["batching_params"]["max_index_records"] == 0 - # This is set by default to 10000 in the index fixture - assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 - assert result["hnsw_params"]["max_mem_queue_size"] == 0 - assert result["hnsw_params"]["index_caching_params"]["max_entries"] == 0 - assert result["hnsw_params"]["index_caching_params"]["expiry"] == 0 - - assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 0 - assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 0 - assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 0 - # This is set by default to * * * * * ? in the index fixture - assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" 
- assert result["hnsw_params"]["healer_params"]["parallelism"] == 0 - - assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 0 - assert result["hnsw_params"]["merge_params"]["reindex_parallelism"] == 0 - - assert result["storage"]["namespace"] == "" - assert result["storage"].set_name == "" - assert result["storage"]["set_name"] == "" - - -@pytest.mark.parametrize("empty_test_case", [None]) -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000) -async def test_index_get_timeout( - session_admin_client, empty_test_case, index, with_latency -): - - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - for i in range(10): - try: - result = await session_admin_client.index_get( - namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 - ) - except AVSServerError as se: - if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - assert se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED - return - assert "In several attempts, the timeout did not happen" == "TEST FAIL" diff --git a/tests/standard/aio/test_admin_client_index_get_status.py b/tests/standard/aio/test_admin_client_index_get_status.py deleted file mode 100644 index fec3fc21..00000000 --- a/tests/standard/aio/test_admin_client_index_get_status.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest -from ...utils import DEFAULT_NAMESPACE - -from hypothesis import given, settings, Verbosity - -from aerospike_vector_search import types, AVSServerError -import grpc - - -@pytest.mark.parametrize("empty_test_case", [None]) -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000) -async def test_index_get_status(session_admin_client, empty_test_case, index): - result : types.IndexStatusResponse = await session_admin_client.index_get_status( - namespace=DEFAULT_NAMESPACE, name=index - ) - assert result.unmerged_record_count == 0 - - -@pytest.mark.parametrize("empty_test_case", [None]) 
-#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000) -async def test_index_get_status_timeout( - session_admin_client, empty_test_case, index, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - for i in range(10): - try: - result = await session_admin_client.index_get_status( - namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 - ) - except AVSServerError as se: - if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - assert se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED - return - assert "In several attempts, the timeout did not happen" == "TEST FAIL" diff --git a/tests/standard/aio/test_admin_client_index_list.py b/tests/standard/aio/test_admin_client_index_list.py deleted file mode 100644 index af633e8c..00000000 --- a/tests/standard/aio/test_admin_client_index_list.py +++ /dev/null @@ -1,42 +0,0 @@ -from aerospike_vector_search import AVSServerError - -import pytest -import grpc -from hypothesis import given, settings, Verbosity - - -@pytest.mark.parametrize("empty_test_case", [None]) -#@given(random_name=index_strategy()) -#@settings(max_examples=1, deadline=1000) -async def test_index_list(session_admin_client, empty_test_case, index): - result = await session_admin_client.index_list(apply_defaults=True) - assert len(result) > 0 - for index in result: - assert isinstance(index["id"]["name"], str) - assert isinstance(index["id"]["namespace"], str) - assert isinstance(index["dimensions"], int) - assert isinstance(index["field"], str) - assert isinstance(index["hnsw_params"]["m"], int) - assert isinstance(index["hnsw_params"]["ef_construction"], int) - assert isinstance(index["hnsw_params"]["ef"], int) - assert isinstance(index["hnsw_params"]["batching_params"]["max_index_records"], int) - assert isinstance(index["hnsw_params"]["batching_params"]["index_interval"], int) - assert isinstance(index["hnsw_params"]["batching_params"]["max_reindex_records"], int) - 
assert isinstance(index["hnsw_params"]["batching_params"]["reindex_interval"], int) - assert isinstance(index["storage"]["namespace"], str) - assert isinstance(index["storage"]["set_name"], str) - - -async def test_index_list_timeout(session_admin_client, with_latency): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - for i in range(10): - try: - result = await session_admin_client.index_list(timeout=0.0001) - - except AVSServerError as se: - if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - assert se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED - return - - assert "In several attempts, the timeout did not happen" == "TEST FAIL" diff --git a/tests/standard/aio/test_admin_client_index_update.py b/tests/standard/aio/test_admin_client_index_update.py deleted file mode 100644 index 734b1b77..00000000 --- a/tests/standard/aio/test_admin_client_index_update.py +++ /dev/null @@ -1,110 +0,0 @@ -import time - -from aerospike_vector_search import types -from utils import DEFAULT_NAMESPACE - -import pytest - -server_defaults = { - "m": 16, - "ef_construction": 100, - "ef": 100, - "batching_params": { - "max_index_records": 10000, - "index_interval": 10000, - } -} - - -class index_update_test_case: - def __init__( - self, - *, - vector_field, - dimensions, - initial_labels, - update_labels, - hnsw_index_update, - timeout - ): - self.vector_field = vector_field - self.dimensions = dimensions - self.initial_labels = initial_labels - self.update_labels = update_labels - self.hnsw_index_update = hnsw_index_update - self.timeout = timeout - - -@pytest.mark.parametrize( - "test_case", - [ - index_update_test_case( - vector_field="update_2", - dimensions=256, - initial_labels={"status": "active"}, - update_labels={"status": "inactive", "region": "us-west"}, - hnsw_index_update=types.HnswIndexUpdate( - batching_params=types.HnswBatchingParams( - max_index_records=2000, - index_interval=20000, - max_reindex_records=1500, - 
reindex_interval=70000 - ), - max_mem_queue_size=1000030, - index_caching_params=types.HnswCachingParams(max_entries=10, expiry=3000), - merge_params=types.HnswIndexMergeParams(index_parallelism=10, reindex_parallelism=3), - healer_params=types.HnswHealerParams(max_scan_rate_per_node=80), - enable_vector_integrity_check=False, - ), - timeout=None, - ), - ], -) -async def test_index_update_async(session_admin_client, test_case, index): - # Update the index with new labels and parameters - await session_admin_client.index_update( - namespace=DEFAULT_NAMESPACE, - name=index, - index_labels=test_case.update_labels, - hnsw_update_params=test_case.hnsw_index_update - ) - - # Allow time for update to be applied - time.sleep(10) - - # Verify the update - result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) - assert result, "Expected result to be non-empty but got an empty dictionary." - - assert result["id"]["namespace"] == DEFAULT_NAMESPACE - - # Assertions based on provided parameters - if test_case.hnsw_index_update.batching_params: - assert result["hnsw_params"]["batching_params"][ - "max_index_records"] == test_case.hnsw_index_update.batching_params.max_index_records - assert result["hnsw_params"]["batching_params"][ - "index_interval"] == test_case.hnsw_index_update.batching_params.index_interval - assert result["hnsw_params"]["batching_params"][ - "max_reindex_records"] == test_case.hnsw_index_update.batching_params.max_reindex_records - assert result["hnsw_params"]["batching_params"][ - "reindex_interval"] == test_case.hnsw_index_update.batching_params.reindex_interval - - assert result["hnsw_params"]["max_mem_queue_size"] == test_case.hnsw_index_update.max_mem_queue_size - - if test_case.hnsw_index_update.index_caching_params: - assert result["hnsw_params"]["index_caching_params"][ - "max_entries"] == test_case.hnsw_index_update.index_caching_params.max_entries - assert 
result["hnsw_params"]["index_caching_params"][ - "expiry"] == test_case.hnsw_index_update.index_caching_params.expiry - - if test_case.hnsw_index_update.merge_params: - assert result["hnsw_params"]["merge_params"][ - "index_parallelism"] == test_case.hnsw_index_update.merge_params.index_parallelism - assert result["hnsw_params"]["merge_params"][ - "reindex_parallelism"] == test_case.hnsw_index_update.merge_params.reindex_parallelism - - if test_case.hnsw_index_update.healer_params: - assert result["hnsw_params"]["healer_params"][ - "max_scan_rate_per_node"] == test_case.hnsw_index_update.healer_params.max_scan_rate_per_node - - assert result["hnsw_params"]["enable_vector_integrity_check"] == test_case.hnsw_index_update.enable_vector_integrity_check diff --git a/tests/standard/aio/test_extensive_vector_search.py b/tests/standard/aio/test_extensive_vector_search.py deleted file mode 100644 index a937dce2..00000000 --- a/tests/standard/aio/test_extensive_vector_search.py +++ /dev/null @@ -1,452 +0,0 @@ -import numpy as np -import asyncio -import pytest -import time -from aerospike_vector_search import types, AVSServerError - -import grpc - -dimensions = 128 -truth_vector_dimensions = 100 -base_vector_number = 10_000 -query_vector_number = 100 - - -# Print the current working directory -def parse_sift_to_numpy_array(length, dim, byte_buffer, dtype): - numpy = np.empty((length,), dtype=object) - - record_length = (dim * 4) + 4 - - for i in range(length): - current_offset = i * record_length - begin = current_offset - vector_begin = current_offset + 4 - end = current_offset + record_length - if np.frombuffer(byte_buffer[begin:vector_begin], dtype=np.int32)[0] != dim: - raise Exception("Failed to parse byte buffer correctly") - numpy[i] = np.frombuffer(byte_buffer[vector_begin:end], dtype=dtype) - return numpy - - -@pytest.fixture -def base_numpy(): - base_filename = "siftsmall/siftsmall_base.fvecs" - with open(base_filename, "rb") as file: - base_bytes = 
bytearray(file.read()) - - base_numpy = parse_sift_to_numpy_array( - base_vector_number, dimensions, base_bytes, np.float32 - ) - - return base_numpy - - -@pytest.fixture -def truth_numpy(): - truth_filename = "siftsmall/siftsmall_groundtruth.ivecs" - with open(truth_filename, "rb") as file: - truth_bytes = bytearray(file.read()) - - truth_numpy = parse_sift_to_numpy_array( - query_vector_number, truth_vector_dimensions, truth_bytes, np.int32 - ) - - return truth_numpy - - -@pytest.fixture -def query_numpy(): - query_filename = "siftsmall/siftsmall_query.fvecs" - with open(query_filename, "rb") as file: - query_bytes = bytearray(file.read()) - - truth_numpy = parse_sift_to_numpy_array( - query_vector_number, dimensions, query_bytes, np.float32 - ) - - return truth_numpy - - -async def put_vector(client, vector, j, set_name): - await client.upsert( - namespace="test", - key=str(j), - record_data={"unit_test": vector}, - set_name=set_name, - ) - - -async def get_vector(client, j, set_name): - result = await client.get(namespace="test", key=str(j), set_name=set_name) - - -async def vector_search(client, vector, name): - result = await client.vector_search( - namespace="test", - index_name=name, - query=vector, - limit=100, - include_fields=["unit_test"], - ) - return result - - -async def vector_search_ef_80(client, vector, name): - result = await client.vector_search( - namespace="test", - index_name=name, - query=vector, - limit=100, - include_fields=["unit_test"], - search_params=types.HnswSearchParams(ef=80), - ) - return result - - -async def grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name, -): - # Vector search all query vectors - tasks = [] - count = 0 - for i in query_numpy: - if count % 2: - tasks.append(vector_search(session_vector_client, i, name)) - else: - tasks.append(vector_search_ef_80(session_vector_client, i, name)) - count += 1 - - results = await asyncio.gather(*tasks) - # Get 
recall numbers for each query - recall_for_each_query = [] - for i, outside in enumerate(truth_numpy): - true_positive = 0 - false_negative = 0 - # Parse all fields for each neighbor into an array - field_list = [] - - for j, result in enumerate(results[i]): - field_list.append(result.fields["unit_test"]) - - for j, index in enumerate(outside): - vector = base_numpy[index].tolist() - if vector in field_list: - true_positive = true_positive + 1 - else: - false_negative = false_negative + 1 - - recall = true_positive / (true_positive + false_negative) - recall_for_each_query.append(recall) - - # Calculate the sum of all values - recall_sum = sum(recall_for_each_query) - - # Calculate the average - average = recall_sum / len(recall_for_each_query) - - assert average > 0.95 - for recall in recall_for_each_query: - assert recall > 0.9 - - -async def test_vector_search( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - await session_admin_client.index_create( - namespace="test", - name="demo1", - vector_field="unit_test", - dimensions=128, - ) - - # Put base vectors for search - tasks = [] - - for j, vector in enumerate(base_numpy): - tasks.append(put_vector(session_vector_client, vector, j, None)) - - tasks.append( - session_vector_client.wait_for_index_completion(namespace="test", name="demo1") - ) - await asyncio.gather(*tasks) - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - await grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo1", - ) - - -async def test_vector_search_with_set_same_as_index( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not 
extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - await session_admin_client.index_create( - namespace="test", - name="demo2", - sets="demo2", - vector_field="unit_test", - dimensions=128, - index_storage=types.IndexStorage(namespace="test", set_name="demo2"), - ) - - # Put base vectors for search - tasks = [] - - for j, vector in enumerate(base_numpy): - tasks.append(put_vector(session_vector_client, vector, j, "demo2")) - - tasks.append( - session_vector_client.wait_for_index_completion(namespace="test", name="demo2") - ) - await asyncio.gather(*tasks) - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - await grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo2", - ) - - -async def test_vector_search_with_set_different_than_name( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - await session_admin_client.index_create( - namespace="test", - name="demo3", - vector_field="unit_test", - dimensions=128, - sets="example1", - index_storage=types.IndexStorage(namespace="test", set_name="demo3"), - ) - - # Put base vectors for search - tasks = [] - - for j, vector in enumerate(base_numpy): - tasks.append(put_vector(session_vector_client, vector, j, "example1")) - - tasks.append( - session_vector_client.wait_for_index_completion(namespace="test", name="demo3") - ) - await asyncio.gather(*tasks) - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - await grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo3", - ) - - -async def 
test_vector_search_with_index_storage_different_than_name( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - await session_admin_client.index_create( - namespace="test", - name="demo4", - vector_field="unit_test", - dimensions=128, - sets="demo4", - index_storage=types.IndexStorage(namespace="test", set_name="example2"), - ) - - # Put base vectors for search - tasks = [] - - for j, vector in enumerate(base_numpy): - tasks.append(put_vector(session_vector_client, vector, j, "demo4")) - - tasks.append( - session_vector_client.wait_for_index_completion(namespace="test", name="demo4") - ) - await asyncio.gather(*tasks) - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - await grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo4", - ) - - -async def test_vector_search_with_index_storage_different_location( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - await session_admin_client.index_create( - namespace="test", - name="demo5", - vector_field="unit_test", - dimensions=128, - sets="example3", - index_storage=types.IndexStorage(namespace="test", set_name="example4"), - ) - - # Put base vectors for search - tasks = [] - - for j, vector in enumerate(base_numpy): - tasks.append(put_vector(session_vector_client, vector, j, "example3")) - - tasks.append( - session_vector_client.wait_for_index_completion(namespace="test", name="demo5") - ) - await asyncio.gather(*tasks) - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - 
time.sleep(5) - - await grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo5", - ) - - -async def test_vector_search_with_separate_namespace( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - await session_admin_client.index_create( - namespace="test", - name="demo6", - vector_field="unit_test", - dimensions=128, - sets="demo6", - index_storage=types.IndexStorage(namespace="index_storage", set_name="demo6"), - ) - - # Put base vectors for search - tasks = [] - - for j, vector in enumerate(base_numpy): - tasks.append(put_vector(session_vector_client, vector, j, "demo6")) - - tasks.append( - session_vector_client.wait_for_index_completion(namespace="test", name="demo6") - ) - await asyncio.gather(*tasks) - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - await grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo6", - ) - - -async def test_vector_vector_search_timeout( - session_vector_client, session_admin_client, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - for i in range(10): - try: - result = await session_vector_client.vector_search( - namespace="test", - index_name="demo2", - query=[0, 1, 2], - limit=100, - include_fields=["unit_test"], - timeout=0.0001, - ) - except AVSServerError as se: - if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - assert se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED - return - assert "In several attempts, the timeout did not happen" == "TEST FAIL" diff --git a/tests/standard/aio/test_service_config.py b/tests/standard/aio/test_service_config.py deleted file mode 100644 index 
bcc94eda..00000000 --- a/tests/standard/aio/test_service_config.py +++ /dev/null @@ -1,496 +0,0 @@ -# import pytest -# import time - -# import os -# import json - -# from aerospike_vector_search import AVSServerError, types -# from aerospike_vector_search.aio import AdminClient - - -# class service_config_parse_test_case: -# def __init__(self, *, service_config_path): -# self.service_config_path = service_config_path - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_parse_test_case( -# service_config_path="service_configs/master.json" -# ), -# ], -# ) -# async def test_admin_client_service_config_parse( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# async with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: -# pass - - -# class service_config_test_case: -# def __init__( -# self, *, service_config_path, namespace, name, vector_field, dimensions -# ): - -# script_dir = os.path.dirname(os.path.abspath(__file__)) - -# self.service_config_path = os.path.abspath( -# os.path.join(script_dir, "..", "..", service_config_path) -# ) - -# with open(self.service_config_path, "rb") as f: -# self.service_config = json.load(f) - -# self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ -# "maxAttempts" -# ] -# self.initial_backoff = int( -# 
self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] -# ) -# self.max_backoff = int( -# self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] -# ) -# self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ -# "backoffMultiplier" -# ] -# self.retryable_status_codes = self.service_config["methodConfig"][0][ -# "retryPolicy" -# ]["retryableStatusCodes"] -# self.namespace = namespace -# self.name = name -# self.vector_field = vector_field -# self.dimensions = dimensions - - -# def calculate_expected_time( -# max_attempts, -# initial_backoff, -# backoff_multiplier, -# max_backoff, -# retryable_status_codes, -# ): - -# current_backkoff = initial_backoff - -# expected_time = 0 -# for attempt in range(max_attempts - 1): -# expected_time += current_backkoff -# current_backkoff *= backoff_multiplier -# current_backkoff = min(current_backkoff, max_backoff) - -# return expected_time - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/retries.json", -# namespace="test", -# name="service_config_index_1", -# vector_field="example_1", -# dimensions=1024, -# ) -# ], -# ) -# async def test_admin_client_service_config_retries( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# async with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# 
service_config_path=test_case.service_config_path, -# ) as client: -# try: -# await client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# await client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time - -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/initial_backoff.json", -# namespace="test", -# name="service_config_index_2", -# vector_field="example_1", -# dimensions=1024, -# ) -# ], -# ) -# async def test_admin_client_service_config_initial_backoff( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# async with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: - -# try: -# await client.index_create( -# 
namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass - -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# await client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time - -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/max_backoff.json", -# namespace="test", -# name="service_config_index_3", -# vector_field="example_1", -# dimensions=1024, -# ), -# service_config_test_case( -# service_config_path="service_configs/max_backoff_lower_than_initial.json", -# namespace="test", -# name="service_config_index_4", -# vector_field="example_1", -# dimensions=1024, -# ), -# ], -# ) -# async def test_admin_client_service_config_max_backoff( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# async with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# 
service_config_path=test_case.service_config_path, -# ) as client: - -# try: -# await client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# await client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/backoff_multiplier.json", -# namespace="test", -# name="service_config_index_5", -# vector_field="example_1", -# dimensions=1024, -# ) -# ], -# ) -# async def test_admin_client_service_config_backoff_multiplier( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# async with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: - -# try: - -# await client.index_create( 
-# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass - -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# await client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/retryable_status_codes.json", -# namespace="test", -# name="service_config_index_6", -# vector_field=None, -# dimensions=None, -# ) -# ], -# ) -# async def test_admin_client_service_config_retryable_status_codes( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# async with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: - -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# 
test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# await client.index_get_status( -# namespace=test_case.namespace, -# name=test_case.name, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time -# assert abs(elapsed_time - expected_time) < 2 diff --git a/tests/standard/aio/test_vector_client_delete.py b/tests/standard/aio/test_vector_client_delete.py deleted file mode 100644 index f5479095..00000000 --- a/tests/standard/aio/test_vector_client_delete.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest -from aerospike_vector_search import AVSServerError -from utils import random_key, DEFAULT_NAMESPACE - -from hypothesis import given, settings, Verbosity -import grpc - - -class delete_test_case: - def __init__( - self, - *, - namespace, - set_name, - timeout, - ): - self.namespace = namespace - self.set_name = set_name - self.timeout = timeout - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - delete_test_case( - namespace=DEFAULT_NAMESPACE, - set_name=None, - timeout=None, - ), - delete_test_case( - namespace=DEFAULT_NAMESPACE, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_delete(session_vector_client, test_case, record): - await session_vector_client.delete( - namespace=test_case.namespace, - key=record, - set_name=test_case.set_name, - timeout=test_case.timeout, - ) - with pytest.raises(AVSServerError) as e_info: - result = await session_vector_client.get( - namespace=test_case.namespace, key=record - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - delete_test_case( - namespace=DEFAULT_NAMESPACE, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_delete_without_record( - session_vector_client, test_case, random_key -): - await 
session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - None, - delete_test_case( - namespace=DEFAULT_NAMESPACE, - set_name=None, - timeout=0.0001, - ), - ], -) -async def test_vector_delete_timeout( - session_vector_client, test_case, record, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - - await session_vector_client.delete( - namespace=test_case.namespace, key=record, timeout=test_case.timeout - ) - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_exists.py b/tests/standard/aio/test_vector_client_exists.py deleted file mode 100644 index 42cdf5c9..00000000 --- a/tests/standard/aio/test_vector_client_exists.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest -import grpc - -from aerospike_vector_search import AVSServerError -from utils import DEFAULT_NAMESPACE - -from hypothesis import given, settings, Verbosity - - -class exists_test_case: - def __init__( - self, - *, - namespace, - set_name, - timeout, - ): - self.namespace = namespace - self.set_name = set_name - self.timeout = timeout - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - exists_test_case( - namespace=DEFAULT_NAMESPACE, - set_name=None, - timeout=None, - ), - exists_test_case( - namespace=DEFAULT_NAMESPACE, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_exists(session_vector_client, test_case, record): - result = await session_vector_client.exists( - namespace=test_case.namespace, - key=record, - ) - assert result is True - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - 
exists_test_case( - namespace="test", set_name=None, timeout=0.0001 - ), - ], -) -async def test_vector_exists_timeout( - session_vector_client, test_case, record, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - result = await session_vector_client.exists( - namespace=test_case.namespace, key=record, timeout=test_case.timeout - ) - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_get.py b/tests/standard/aio/test_vector_client_get.py deleted file mode 100644 index 305ca760..00000000 --- a/tests/standard/aio/test_vector_client_get.py +++ /dev/null @@ -1,159 +0,0 @@ -import pytest -import grpc - -from aerospike_vector_search import AVSServerError -from utils import DEFAULT_NAMESPACE, random_key - -from hypothesis import given, settings, Verbosity - - -class get_test_case: - def __init__( - self, - *, - namespace, - include_fields, - exclude_fields, - set_name, - expected_fields, - timeout, - ): - self.namespace = namespace - self.include_fields = include_fields - self.exclude_fields = exclude_fields - self.set_name = set_name - self.expected_fields = expected_fields - self.timeout = timeout - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "record,test_case", - [ - ( - {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"skills": 1024}))}, - get_test_case( - namespace=DEFAULT_NAMESPACE, - include_fields=["skills"], - exclude_fields = None, - set_name=None, - expected_fields={"skills": 1024}, - timeout=None, - ), - ), - ( - {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": [float(i) for i in range(1024)]}))}, - get_test_case( - namespace=DEFAULT_NAMESPACE, - include_fields=["english"], - exclude_fields = None, - set_name=None, - expected_fields={"english": [float(i) 
for i in range(1024)]}, - timeout=None, - ), - ), - ( - {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, - get_test_case( - namespace=DEFAULT_NAMESPACE, - include_fields=["english"], - exclude_fields = None, - set_name=None, - expected_fields={"english": 1}, - timeout=None, - ), - ), - ( - {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, - get_test_case( - namespace=DEFAULT_NAMESPACE, - include_fields=None, - exclude_fields=["spanish"], - set_name=None, - expected_fields={"english": 1}, - timeout=None, - ), - ), - ( - {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, - get_test_case( - namespace=DEFAULT_NAMESPACE, - include_fields=["spanish"], - exclude_fields=["spanish"], - set_name=None, - expected_fields={}, - timeout=None, - ), - ), - ( - {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, - get_test_case( - namespace=DEFAULT_NAMESPACE, - include_fields=[], - exclude_fields=None, - set_name=None, - expected_fields={}, - timeout=None, - ), - ), - ( - {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, - get_test_case( - namespace=DEFAULT_NAMESPACE, - include_fields=None, - exclude_fields=[], - set_name=None, - expected_fields={"english": 1, "spanish": 2}, - timeout=None, - ), - ), - ], - indirect=["record"], -) -async def test_vector_get(session_vector_client, test_case, record): - result = await session_vector_client.get( - namespace=test_case.namespace, - key=record, - include_fields=test_case.include_fields, - exclude_fields=test_case.exclude_fields, - ) - assert result.key.namespace == test_case.namespace - if test_case.set_name == None: - test_case.set_name = "" - assert result.key.set == test_case.set_name - assert result.key.key == record - - assert result.fields == test_case.expected_fields - - 
-#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - get_test_case( - namespace="test", - include_fields=["skills"], - exclude_fields = None, - set_name=None, - expected_fields=None, - timeout=0.0001, - ), - ], -) -async def test_vector_get_timeout( - session_vector_client, test_case, random_key, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - result = await session_vector_client.get( - namespace=test_case.namespace, - key=random_key, - include_fields=test_case.include_fields, - exclude_fields=test_case.exclude_fields, - timeout=test_case.timeout, - ) - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_index_get_percent_unmerged.py b/tests/standard/aio/test_vector_client_index_get_percent_unmerged.py deleted file mode 100644 index ca00f910..00000000 --- a/tests/standard/aio/test_vector_client_index_get_percent_unmerged.py +++ /dev/null @@ -1,66 +0,0 @@ -import time - -import pytest - - -class index_get_percent_unmerged_test_case: - def __init__( - self, - *, - namespace, - timeout, - expected_unmerged_percent - ): - self.namespace = namespace - self.timeout = timeout - self.expected_unmerged_percent = expected_unmerged_percent - - -@pytest.mark.parametrize( - "records,test_case", - [ - ( - {"num_records": 1_000}, - index_get_percent_unmerged_test_case( - namespace="test", - timeout=None, - expected_unmerged_percent=0.0, - ) - ), - ( - {"num_records": 1_000}, - index_get_percent_unmerged_test_case( - namespace="test", - timeout=60, - expected_unmerged_percent=0.0, - ), - ), - # the 500 records won't be indexed until index_interval - # is hit so we should expect 500.0% unmerged - ( - {"num_records": 500}, - index_get_percent_unmerged_test_case( - namespace="test", - timeout=None, - 
expected_unmerged_percent=500.0, - ), - ) - ], - indirect=["records"], -) -async def test_client_index_get_percent_unmerged( - session_vector_client, - index, - records, - test_case, -): - # need some time for index stats to be counted server side - time.sleep(1) - - percent_unmerged = await session_vector_client.index_get_percent_unmerged( - namespace=test_case.namespace, - name=index, - timeout=test_case.timeout, - ) - - assert percent_unmerged >= test_case.expected_unmerged_percent diff --git a/tests/standard/aio/test_vector_client_insert.py b/tests/standard/aio/test_vector_client_insert.py deleted file mode 100644 index 04e587cf..00000000 --- a/tests/standard/aio/test_vector_client_insert.py +++ /dev/null @@ -1,127 +0,0 @@ -import pytest -from aerospike_vector_search import AVSServerError -from utils import random_key, DEFAULT_NAMESPACE - -from hypothesis import given, settings, Verbosity -import asyncio -from hypothesis import given, settings - -import grpc - - -class insert_test_case: - def __init__( - self, - *, - namespace, - record_data, - set_name, - timeout, - ): - self.namespace = namespace - self.record_data = record_data - self.set_name = set_name - self.timeout = timeout - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - insert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=None, - ), - insert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"homeSkills": [float(i) for i in range(1024)]}, - set_name=None, - timeout=None, - ), - insert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"english": [bool(i) for i in range(1024)]}, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_insert_without_existing_record( - session_vector_client, test_case, random_key -): - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - await 
session_vector_client.insert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - insert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_insert_with_existing_record( - session_vector_client, test_case, record -): - with pytest.raises(AVSServerError) as e_info: - await session_vector_client.insert( - namespace=test_case.namespace, - key=record, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - insert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=0.0001, - ), - ], -) -async def test_vector_insert_timeout( - session_vector_client, test_case, random_key, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - await session_vector_client.insert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - timeout=test_case.timeout, - ) - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_is_indexed.py b/tests/standard/aio/test_vector_client_is_indexed.py deleted file mode 100644 index f222d6b3..00000000 --- a/tests/standard/aio/test_vector_client_is_indexed.py +++ /dev/null @@ -1,49 +0,0 @@ -import pytest -import 
time - -from utils import DEFAULT_NAMESPACE -from .aio_utils import wait_for_index -from aerospike_vector_search import AVSServerError - -import grpc - - -async def test_vector_is_indexed( - session_admin_client, - session_vector_client, - index, - record -): - # wait for the record to be indexed - await wait_for_index( - admin_client=session_admin_client, - namespace=DEFAULT_NAMESPACE, - index=index - ) - - result = await session_vector_client.is_indexed( - namespace=DEFAULT_NAMESPACE, - key=record, - index_name=index, - ) - assert result is True - - -async def test_vector_is_indexed_timeout( - session_vector_client, with_latency, index, record -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - for _ in range(10): - try: - await session_vector_client.is_indexed( - namespace=DEFAULT_NAMESPACE, - key=record, - index_name=index, - timeout=0.0001, - ) - except AVSServerError as se: - if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - assert se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED - return - assert "In several attempts, the timeout did not happen" == "TEST FAIL" \ No newline at end of file diff --git a/tests/standard/aio/test_vector_client_search_by_key.py b/tests/standard/aio/test_vector_client_search_by_key.py deleted file mode 100644 index 7bd2698f..00000000 --- a/tests/standard/aio/test_vector_client_search_by_key.py +++ /dev/null @@ -1,446 +0,0 @@ -import numpy as np -import pytest - -from utils import DEFAULT_NAMESPACE -from .aio_utils import wait_for_index -from aerospike_vector_search import types - -INDEX = "sbk_index" -NAMESPACE = DEFAULT_NAMESPACE -DIMENSIONS = 3 -VEC_BIN = "asbkvec" -SET_NAME = "test_set" - - -class vector_search_by_key_test_case: - def __init__( - self, - *, - index_dimensions, - vector_field, - limit, - key, - key_namespace, - search_namespace, - include_fields, - exclude_fields, - key_set, - expected_results, - ): - self.index_dimensions = index_dimensions - self.vector_field 
= vector_field - self.limit = limit - self.key = key - self.search_namespace = search_namespace - self.include_fields = include_fields - self.exclude_fields = exclude_fields - self.key_set = key_set - self.expected_results = expected_results - self.key_namespace = key_namespace - - -@pytest.fixture(scope="module") -async def setup_index( - session_admin_client, -): - await session_admin_client.index_create( - namespace=DEFAULT_NAMESPACE, - name=INDEX, - vector_field=VEC_BIN, - dimensions=DIMENSIONS, - index_params=types.HnswParams( - batching_params=types.HnswBatchingParams( - # 10_000 is the minimum value, in order for the tests to run as - # fast as possible we set it to the minimum value so records are indexed - # quickly - index_interval=10_000, - ), - healer_params=types.HnswHealerParams( - # run the healer every second - # for fast indexing - schedule="* * * * * ?" - ) - ) - ) - - yield - - await session_admin_client.index_drop( - namespace=DEFAULT_NAMESPACE, - name=INDEX, - ) - - -@pytest.fixture(scope="module") -async def setup_records( - session_vector_client, -): - recs = { - "rec1": { - "bin": 1, - VEC_BIN: [1.0] * DIMENSIONS, - }, - 2: { - "bin": 2, - VEC_BIN: [2.0] * DIMENSIONS, - }, - bytes("rec5", "utf-8"): { - "bin": 5, - VEC_BIN: [5.0] * DIMENSIONS, - }, - } - - keys = [] - for key, record in recs.items(): - keys.append(key) - await session_vector_client.upsert( - namespace=DEFAULT_NAMESPACE, - key=key, - record_data=record, - ) - - # write some records for set tests - set_recs = { - "srec100": { - "bin": 100, - VEC_BIN: [100.0] * DIMENSIONS, - }, - "srec101": { - "bin": 101, - VEC_BIN: [101.0] * DIMENSIONS, - }, - } - - for key, record in set_recs.items(): - keys.append(key) - await session_vector_client.upsert( - namespace=DEFAULT_NAMESPACE, - key=key, - record_data=record, - set_name=SET_NAME, - ) - - yield - - for key in keys: - await session_vector_client.delete( - namespace=DEFAULT_NAMESPACE, - key=key, - ) - - for key in keys: - await 
session_vector_client.delete( - namespace=DEFAULT_NAMESPACE, - key=key, - set_name=SET_NAME, - ) - - -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - # test string key - vector_search_by_key_test_case( - index_dimensions=3, - vector_field=VEC_BIN, - limit=2, - key="rec1", - key_namespace=DEFAULT_NAMESPACE, - search_namespace=DEFAULT_NAMESPACE, - include_fields=None, - exclude_fields=None, - key_set=None, - expected_results=[ - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key="rec1", - ), - fields={ - "bin": 1, - VEC_BIN: [1.0, 1.0, 1.0], - }, - distance=0.0, - ), - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key=2, - ), - fields={ - "bin": 2, - VEC_BIN: [2.0, 2.0, 2.0], - }, - distance=3.0, - ), - ], - ), - # test int key - vector_search_by_key_test_case( - index_dimensions=3, - vector_field=VEC_BIN, - limit=3, - key=2, - key_namespace=DEFAULT_NAMESPACE, - search_namespace=DEFAULT_NAMESPACE, - include_fields=["bin"], - exclude_fields=["bin"], - key_set=None, - expected_results=[ - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key=2, - ), - fields={}, - distance=0.0, - ), - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key="rec1", - ), - fields={}, - distance=3.0, - ), - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key=bytes("rec5", "utf-8"), - ), - fields={}, - distance=27.0, - ), - ], - ), - # test bytes key - vector_search_by_key_test_case( - index_dimensions=3, - vector_field=VEC_BIN, - limit=3, - key=bytes("rec5", "utf-8"), - key_namespace=DEFAULT_NAMESPACE, - search_namespace=DEFAULT_NAMESPACE, - include_fields=["bin"], - exclude_fields=["bin"], - key_set=None, - expected_results=[ - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key=bytes("rec5", "utf-8"), - ), - fields={}, - distance=0.0, - ), - types.Neighbor( - key=types.Key( - 
namespace=DEFAULT_NAMESPACE, - set="", - key=2, - ), - fields={}, - distance=27.0, - ), - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key="rec1", - ), - fields={}, - distance=48.0, - ), - ], - ), - # # test bytearray key - # # TODO: add a bytearray key case, bytearrays are not hashable - # # so this is not easily added. Leaving it for now. - # # vector_search_by_key_test_case( - # # index_name="field_filter", - # # index_dimensions=3, - # # vector_field=VEC_BIN, - # # limit=3, - # # key=bytearray("rec1", "utf-8"), - # # namespace=DEFAULT_NAMESPACE, - # # include_fields=["bin"], - # # exclude_fields=["bin"], - # # key_set=None, - # # record_data={ - # # bytearray("rec1", "utf-8"): { - # # "bin": 1, - # # VEC_BIN: [1.0, 1.0, 1.0], - # # }, - # # bytearray("rec1", "utf-8"): { - # # "bin": 2, - # # VEC_BIN: [2.0, 2.0, 2.0], - # # }, - # # }, - # # expected_results=[ - # # types.Neighbor( - # # key=types.Key( - # # namespace=DEFAULT_NAMESPACE, - # # set="", - # # key=2, - # # ), - # # fields={}, - # # distance=3.0, - # # ), - # # ], - # # ), - # test with set name - vector_search_by_key_test_case( - index_dimensions=3, - vector_field=VEC_BIN, - limit=2, - key="srec100", - key_namespace=DEFAULT_NAMESPACE, - search_namespace=DEFAULT_NAMESPACE, - include_fields=None, - exclude_fields=None, - key_set=SET_NAME, - expected_results=[ - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set=SET_NAME, - key="srec100", - ), - fields={ - "bin": 100, - VEC_BIN: [100.0] * DIMENSIONS, - }, - distance=0.0, - ), - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set=SET_NAME, - key="srec101", - ), - fields={ - "bin": 101, - VEC_BIN: [101.0] * DIMENSIONS, - }, - distance=3.0, - ), - ], - ), - ], -) -async def test_vector_search_by_key( - session_vector_client, - session_admin_client, - setup_index, - setup_records, - test_case, -): - await wait_for_index(session_admin_client, DEFAULT_NAMESPACE, INDEX) - - results = await 
session_vector_client.vector_search_by_key( - search_namespace=test_case.search_namespace, - index_name=INDEX, - key=test_case.key, - key_namespace=test_case.key_namespace, - vector_field=test_case.vector_field, - limit=test_case.limit, - key_set=test_case.key_set, - include_fields=test_case.include_fields, - exclude_fields=test_case.exclude_fields, - ) - - assert results == test_case.expected_results - - -async def test_vector_search_by_key_different_namespaces( - session_vector_client, - session_admin_client, -): - - await session_admin_client.index_create( - namespace="index_storage", - name="diff_ns_idx", - vector_field="vec", - dimensions=3, - index_params=types.HnswParams( - batching_params=types.HnswBatchingParams( - # 10_000 is the minimum value, in order for the tests to run as - # fast as possible we set it to the minimum value so records are indexed - # quickly - index_interval=10_000, - ), - healer_params=types.HnswHealerParams( - # run the healer every second - # for fast indexing - schedule="* * * * * ?" 
- ) - ) - ) - - await session_vector_client.upsert( - namespace="test", - key="search_by", - record_data={ - "bin": 1, - "vec": [1.0, 1.0, 1.0], - }, - ) - - await session_vector_client.upsert( - namespace="index_storage", - key="search_for", - record_data={ - "bin": 2, - "vec": [2.0, 2.0, 2.0], - }, - ) - - await wait_for_index(session_admin_client, "index_storage", "diff_ns_idx") - - results = await session_vector_client.vector_search_by_key( - search_namespace="index_storage", - index_name="diff_ns_idx", - key="search_by", - key_namespace="test", - vector_field="vec", - limit=1, - ) - - expected = [ - types.Neighbor( - key=types.Key( - namespace="index_storage", - set="", - key="search_for", - ), - fields={ - "bin": 2, - "vec": [2.0, 2.0, 2.0], - }, - distance=3.0, - ), - ] - - assert results == expected - - await session_vector_client.delete( - namespace="test", - key="search_by", - ) - - await session_vector_client.delete( - namespace="index_storage", - key="search_for", - ) - - await session_admin_client.index_drop( - namespace="index_storage", - name="diff_ns_idx", - ) \ No newline at end of file diff --git a/tests/standard/aio/test_vector_client_update.py b/tests/standard/aio/test_vector_client_update.py deleted file mode 100644 index fc5b5264..00000000 --- a/tests/standard/aio/test_vector_client_update.py +++ /dev/null @@ -1,117 +0,0 @@ -import pytest -from aerospike_vector_search import AVSServerError -from ...utils import random_key, DEFAULT_NAMESPACE - -from hypothesis import given, settings, Verbosity -import grpc - - -class update_test_case: - def __init__( - self, - *, - namespace, - record_data, - set_name, - timeout, - ): - self.namespace = namespace - self.record_data = record_data - self.set_name = set_name - self.timeout = timeout - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - update_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in 
range(1024)]}, - set_name=None, - timeout=None, - ), - update_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"english": [float(i) for i in range(1024)]}, - set_name=None, - timeout=None, - ), - update_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"english": [bool(i) for i in range(1024)]}, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_update_with_existing_record( - session_vector_client, test_case, record -): - await session_vector_client.update( - namespace=test_case.namespace, - key=record, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - update_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_update_without_existing_record( - session_vector_client, test_case, random_key -): - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - with pytest.raises(AVSServerError) as e_info: - await session_vector_client.update( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - update_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=0.0001, - ), - ], -) -async def test_vector_update_timeout( - session_vector_client, test_case, random_key, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - await session_vector_client.update( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - 
timeout=test_case.timeout, - ) - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_upsert.py b/tests/standard/aio/test_vector_client_upsert.py deleted file mode 100644 index cac115b4..00000000 --- a/tests/standard/aio/test_vector_client_upsert.py +++ /dev/null @@ -1,153 +0,0 @@ -import pytest -from ...utils import random_key, DEFAULT_NAMESPACE - -from hypothesis import given, settings, Verbosity -import numpy as np - -from aerospike_vector_search import AVSServerError -import grpc - - -class upsert_test_case: - def __init__(self, *, namespace, record_data, set_name, timeout, key=None): - self.namespace = namespace - self.record_data = record_data - self.set_name = set_name - self.timeout = timeout - - self.key = key - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - upsert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=None, - ), - upsert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"english": [float(i) for i in range(1024)]}, - set_name=None, - timeout=None, - ), - upsert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"english": [bool(i) for i in range(1024)]}, - set_name=None, - timeout=None, - ), - ], -) -async def test_vector_upsert_without_existing_record( - session_vector_client, test_case, random_key -): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - upsert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=None, 
- ), - ], -) -async def test_vector_upsert_with_existing_record( - session_vector_client, test_case, random_key -): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - - -@pytest.mark.parametrize( - "test_case", - [ - upsert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=None, - key=np.int32(31), - ), - upsert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=None, - key=np.array([b"a", b"b", b"c"]), - ), - ], -) -async def test_vector_upsert_with_numpy_key(session_vector_client, test_case): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=test_case.key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - - await session_vector_client.delete( - namespace=test_case.namespace, - key=test_case.key, - ) - - -#@given(random_key=key_strategy()) -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - None, - upsert_test_case( - namespace=DEFAULT_NAMESPACE, - record_data={"math": [i for i in range(1024)]}, - set_name=None, - timeout=0.0001, - ), - ], -) -async def test_vector_upsert_timeout( - session_vector_client, test_case, random_key, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - with pytest.raises(AVSServerError) as e_info: - for i in range(10): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - timeout=test_case.timeout, - ) - assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_search.py 
b/tests/standard/aio/test_vector_search.py deleted file mode 100644 index ff597fe0..00000000 --- a/tests/standard/aio/test_vector_search.py +++ /dev/null @@ -1,172 +0,0 @@ -import numpy as np -import asyncio -import pytest - -from aerospike_vector_search import types -from utils import DEFAULT_NAMESPACE -from .aio_utils import wait_for_index - - -class vector_search_test_case: - def __init__( - self, - *, - index_name, - index_dimensions, - vector_field, - limit, - query, - namespace, - include_fields, - exclude_fields, - set_name, - record_data, - expected_results, - ): - self.index_name = index_name - self.index_dimensions = index_dimensions - self.vector_field = vector_field - self.limit = limit - self.query = query - self.namespace = namespace - self.include_fields = include_fields - self.exclude_fields = exclude_fields - self.set_name = set_name - self.record_data = record_data - self.expected_results = expected_results - - -# TODO add a teardown -#@settings(max_examples=1, deadline=1000) -@pytest.mark.parametrize( - "test_case", - [ - vector_search_test_case( - index_name="basic_search", - index_dimensions=3, - vector_field="vecs", - limit=3, - query=[0.0, 0.0, 0.0], - namespace=DEFAULT_NAMESPACE, - include_fields=None, - exclude_fields = None, - set_name=None, - record_data={ - "rec1": { - "bin1": 1, - "vecs": [1.0, 1.0, 1.0], - }, - }, - expected_results=[ - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key="rec1", - ), - fields={ - "bin1": 1, - "vecs": [1.0, 1.0, 1.0], - }, - distance=3.0, - ), - ], - ), - vector_search_test_case( - index_name="field_filter", - index_dimensions=3, - vector_field="vecs", - limit=3, - query=[0.0, 0.0, 0.0], - namespace=DEFAULT_NAMESPACE, - include_fields=["bin1"], - exclude_fields=["bin1"], - set_name=None, - record_data={ - "rec1": { - "bin1": 1, - "vecs": [1.0, 1.0, 1.0], - }, - }, - expected_results=[ - types.Neighbor( - key=types.Key( - namespace=DEFAULT_NAMESPACE, - set="", - key="rec1", - 
), - fields={}, - distance=3.0, - ), - ], - ), - ], -) -async def test_vector_search( - session_vector_client, - session_admin_client, - test_case, -): - - await session_admin_client.index_create( - namespace=test_case.namespace, - name=test_case.index_name, - vector_field=test_case.vector_field, - dimensions=test_case.index_dimensions, - index_params=types.HnswParams( - batching_params=types.HnswBatchingParams( - # 10_000 is the minimum value, in order for the tests to run as - # fast as possible we set it to the minimum value so records are indexed - # quickly - index_interval=10_000, - ), - healer_params=types.HnswHealerParams( - # run the healer every second - # for fast indexing - schedule="* * * * * ?" - ) - ) - ) - - tasks = [] - for key, rec in test_case.record_data.items(): - tasks.append(session_vector_client.upsert( - namespace=test_case.namespace, - key=key, - record_data=rec, - set_name=test_case.set_name, - )) - - tasks.append( - wait_for_index( - session_admin_client, - namespace=test_case.namespace, - index=test_case.index_name, - ) - ) - await asyncio.gather(*tasks) - - results = await session_vector_client.vector_search( - namespace=test_case.namespace, - index_name=test_case.index_name, - query=test_case.query, - limit=test_case.limit, - include_fields=test_case.include_fields, - exclude_fields=test_case.exclude_fields, - ) - - assert results == test_case.expected_results - - tasks = [] - for key in test_case.record_data: - tasks.append(session_vector_client.delete( - namespace=test_case.namespace, - key=key, - )) - - await asyncio.gather(*tasks) - - await session_admin_client.index_drop( - namespace=test_case.namespace, - name=test_case.index_name, - ) diff --git a/tests/standard/conftest.py b/tests/standard/conftest.py index fd46e19c..1aa9ff31 100644 --- a/tests/standard/conftest.py +++ b/tests/standard/conftest.py @@ -1,5 +1,357 @@ +import asyncio +import random +import string + +from aerospike_vector_search import Client +from 
aerospike_vector_search.aio import Client as AsyncClient +from aerospike_vector_search.admin import Client as AdminClient +from aerospike_vector_search.aio.admin import Client as AsyncAdminClient +from aerospike_vector_search import types, AVSServerError + +from utils import gen_records, DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD +import grpc import pytest +########################################## +###### GLOBALS +########################################## + +# default test values +DEFAULT_INDEX_ARGS = { + "namespace": DEFAULT_NAMESPACE, + "vector_field": DEFAULT_VECTOR_FIELD, + "dimensions": DEFAULT_INDEX_DIMENSION, +} + +DEFAULT_RECORD_GENERATOR = gen_records +DEFAULT_NUM_RECORDS = 1000 +DEFAULT_RECORDS_ARGS = { + "record_generator": DEFAULT_RECORD_GENERATOR, + "namespace": DEFAULT_NAMESPACE, + "vector_field": DEFAULT_VECTOR_FIELD, + "dimensions": DEFAULT_INDEX_DIMENSION, + "num_records": DEFAULT_NUM_RECORDS, +} + + +########################################## +###### FIXTURES +########################################## + +@pytest.fixture(scope="module", autouse=True) +def drop_all_indexes( + username, + password, + root_certificate, + host, + port, + certificate_chain, + private_key, + is_loadbalancer, + ssl_target_name_override, +): + + if root_certificate: + with open(root_certificate, "rb") as f: + root_certificate = f.read() + + if certificate_chain: + with open(certificate_chain, "rb") as f: + certificate_chain = f.read() + if private_key: + with open(private_key, "rb") as f: + private_key = f.read() + + with AdminClient( + seeds=types.HostPort(host=host, port=port), + is_loadbalancer=is_loadbalancer, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + ssl_target_name_override=ssl_target_name_override, + ) as client: + index_list = client.index_list() + + tasks = [] + for item in index_list: + client.index_drop(namespace="test", 
name=item["id"]["name"]) + + +@pytest.fixture(scope="session") +def event_loop(): + """ + Create an event loop for the test session. + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + #HACK the async client schedules tasks in its init function so we need + # to run the event loop to allow the tasks to be scheduled + yield loop + loop.close() + + +class AsyncClientWrapper(): + def __init__(self, client): + self.client = client + + def __getattr__(self, name): + attr = getattr(self.client, name) + if asyncio.iscoroutinefunction(attr): + # Wrap async methods to run in the current event loop + def sync_method(*args, **kwargs): + loop = asyncio.get_event_loop() + return loop.run_until_complete(attr(*args, **kwargs)) + + return sync_method + return attr + + +async def new_wrapped_async_client( + host, + port, + username, + password, + root_certificate, + certificate_chain, + private_key, + is_loadbalancer, + ssl_target_name_override +): + client = AsyncClient( + seeds=types.HostPort(host=host, port=port), + is_loadbalancer=is_loadbalancer, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + ssl_target_name_override=ssl_target_name_override + ) + return AsyncClientWrapper(client) + + +async def new_wrapped_async_admin_client( + host, + port, + username, + password, + root_certificate, + certificate_chain, + private_key, + is_loadbalancer, + ssl_target_name_override +): + client = AsyncAdminClient( + seeds=types.HostPort(host=host, port=port), + is_loadbalancer=is_loadbalancer, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + ssl_target_name_override=ssl_target_name_override + ) + return AsyncClientWrapper(client) + + +@pytest.fixture(scope="module") +def session_admin_client( + username, + password, + root_certificate, + host, + port, + certificate_chain, + 
private_key, + is_loadbalancer, + ssl_target_name_override, + async_client, + event_loop, +): + + if root_certificate: + with open(root_certificate, "rb") as f: + root_certificate = f.read() + + if certificate_chain: + with open(certificate_chain, "rb") as f: + certificate_chain = f.read() + if private_key: + with open(private_key, "rb") as f: + private_key = f.read() + + if async_client: + loop = asyncio.get_event_loop() + client = loop.run_until_complete(new_wrapped_async_admin_client( + host=host, + port=port, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + is_loadbalancer=is_loadbalancer, + ssl_target_name_override=ssl_target_name_override + )) + else: + client = AdminClient( + seeds=types.HostPort(host=host, port=port), + is_loadbalancer=is_loadbalancer, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + ssl_target_name_override=ssl_target_name_override + ) + + yield client + client.close() + + +@pytest.fixture(scope="module") +def session_vector_client( + username, + password, + root_certificate, + host, + port, + certificate_chain, + private_key, + is_loadbalancer, + ssl_target_name_override, + async_client, + event_loop, +): + + if root_certificate: + with open(root_certificate, "rb") as f: + root_certificate = f.read() + + if certificate_chain: + with open(certificate_chain, "rb") as f: + certificate_chain = f.read() + if private_key: + with open(private_key, "rb") as f: + private_key = f.read() + + if async_client: + loop = asyncio.get_event_loop() + client = loop.run_until_complete(new_wrapped_async_client( + host=host, + port=port, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + is_loadbalancer=is_loadbalancer, + ssl_target_name_override=ssl_target_name_override + 
)) + else: + client = Client( + seeds=types.HostPort(host=host, port=port), + is_loadbalancer=is_loadbalancer, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + ssl_target_name_override=ssl_target_name_override + ) + + yield client + client.close() + + +@pytest.fixture() +def index_name(): + length = random.randint(1, 15) + return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + + +@pytest.fixture(params=[DEFAULT_INDEX_ARGS]) +def index(session_admin_client, index_name, request): + args = request.param + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + session_admin_client.index_create( + name = index_name, + namespace = namespace, + vector_field = vector_field, + dimensions = dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" 
+ ) + ) + ) + yield index_name + try: + session_admin_client.index_drop(namespace=namespace, name=index_name) + except AVSServerError as se: + if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: + pass + else: + raise + + +@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) +def records(session_vector_client, request): + args = request.param + record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) + namespace = args.get("namespace", DEFAULT_NAMESPACE) + num_records = args.get("num_records", DEFAULT_NUM_RECORDS) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) + keys = [] + for key, rec in record_generator(count=num_records, vec_bin=vector_field, vec_dim=dimensions): + session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) + keys.append(key) + yield keys + for key in keys: + session_vector_client.delete(key=key, namespace=namespace) + + +@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) +def record(session_vector_client, request): + args = request.param + record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) + key, rec = next(record_generator(count=1, vec_bin=vector_field, vec_dim=dimensions)) + session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) + yield key + session_vector_client.delete(key=key, namespace=namespace) + + +########################################## +###### SUITE FLAGS +########################################## def pytest_addoption(parser): parser.addoption("--username", action="store", default=None, help="AVS Username") @@ -39,6 +391,16 @@ def pytest_addoption(parser): 
action="store_true", help="Run extensive vector search testing", ) + parser.addoption( + "--async", + action="store_true", + help="Run tests using the async client", + ) + + +@pytest.fixture(scope="module", autouse=True) +def async_client(request): + return request.config.getoption("--async") @pytest.fixture(scope="module", autouse=True) diff --git a/tests/standard/sync/requirements.txt b/tests/standard/requirements.txt similarity index 100% rename from tests/standard/sync/requirements.txt rename to tests/standard/requirements.txt diff --git a/tests/standard/sync/__init__.py b/tests/standard/sync/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/standard/sync/conftest.py b/tests/standard/sync/conftest.py deleted file mode 100644 index d6ffff5b..00000000 --- a/tests/standard/sync/conftest.py +++ /dev/null @@ -1,271 +0,0 @@ -import pytest -import random -import string - -from aerospike_vector_search import Client -from aerospike_vector_search.admin import Client as AdminClient -from aerospike_vector_search import types, AVSServerError - -from .sync_utils import gen_records -import grpc - -#import logging -#logger = logging.getLogger(__name__) -#logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.DEBUG) - -# default test values -DEFAULT_NAMESPACE = "test" -DEFAULT_INDEX_DIMENSION = 128 -DEFAULT_VECTOR_FIELD = "vector" -DEFAULT_INDEX_ARGS = { - "namespace": DEFAULT_NAMESPACE, - "vector_field": DEFAULT_VECTOR_FIELD, - "dimensions": DEFAULT_INDEX_DIMENSION, -} - -DEFAULT_RECORD_GENERATOR = gen_records -DEFAULT_NUM_RECORDS = 1000 -DEFAULT_RECORDS_ARGS = { - "record_generator": DEFAULT_RECORD_GENERATOR, - "namespace": DEFAULT_NAMESPACE, - "vector_field": DEFAULT_VECTOR_FIELD, - "dimensions": DEFAULT_INDEX_DIMENSION, - "num_records": DEFAULT_NUM_RECORDS, -} - -@pytest.fixture(scope="module", autouse=True) -def drop_all_indexes( - username, - password, - root_certificate, - host, - port, - certificate_chain, - 
private_key, - is_loadbalancer, - ssl_target_name_override, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - ) as client: - index_list = client.index_list() - - tasks = [] - for item in index_list: - client.index_drop(namespace="test", name=item["id"]["name"]) - - -@pytest.fixture(scope="module") -def session_admin_client( - username, - password, - root_certificate, - host, - port, - certificate_chain, - private_key, - is_loadbalancer, - ssl_target_name_override -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - client = AdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override - ) - yield client - client.close() - - -@pytest.fixture(scope="module") -def session_vector_client( - username, - password, - root_certificate, - host, - port, - certificate_chain, - private_key, - is_loadbalancer, - ssl_target_name_override -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: 
- certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - client = Client( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override - ) - yield client - client.close() - - -@pytest.fixture -def function_admin_client( - username, - password, - root_certificate, - host, - port, - certificate_chain, - private_key, - is_loadbalancer, - ssl_target_name_override -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - client = AdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override - ) - yield client - client.close() - - -@pytest.fixture() -def index_name(): - length = random.randint(1, 15) - return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) - - -@pytest.fixture(params=[DEFAULT_INDEX_ARGS]) -def index(session_admin_client, index_name, request): - args = request.param - namespace = args.get("namespace", DEFAULT_NAMESPACE) - vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) - dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - session_admin_client.index_create( - name = index_name, - namespace = namespace, - vector_field = vector_field, - dimensions = dimensions, - index_params=types.HnswParams( - batching_params=types.HnswBatchingParams( - # 10_000 is the minimum value, in order for the tests 
to run as - # fast as possible we set it to the minimum value so records are indexed - # quickly - index_interval=10_000, - ), - healer_params=types.HnswHealerParams( - # run the healer every second - # for fast indexing - schedule="* * * * * ?" - ) - ) - ) - yield index_name - try: - session_admin_client.index_drop(namespace=namespace, name=index_name) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: - pass - else: - raise - - -@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) -def records(session_vector_client, request): - args = request.param - record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) - namespace = args.get("namespace", DEFAULT_NAMESPACE) - num_records = args.get("num_records", DEFAULT_NUM_RECORDS) - vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) - dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - set_name = args.get("set_name", None) - keys = [] - for key, rec in record_generator(count=num_records, vec_bin=vector_field, vec_dim=dimensions): - session_vector_client.upsert( - namespace=namespace, - key=key, - record_data=rec, - set_name=set_name, - ) - keys.append(key) - yield keys - for key in keys: - session_vector_client.delete(key=key, namespace=namespace) - - -@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) -def record(session_vector_client, request): - args = request.param - record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) - namespace = args.get("namespace", DEFAULT_NAMESPACE) - vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) - dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - set_name = args.get("set_name", None) - key, rec = next(record_generator(count=1, vec_bin=vector_field, vec_dim=dimensions)) - session_vector_client.upsert( - namespace=namespace, - key=key, - record_data=rec, - set_name=set_name, - ) - yield key - session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file diff 
--git a/tests/standard/sync/sync_utils.py b/tests/standard/sync/sync_utils.py deleted file mode 100644 index 05aeeb94..00000000 --- a/tests/standard/sync/sync_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -import time - -def drop_specified_index(admin_client, namespace, name): - admin_client.index_drop(namespace=namespace, name=name) - - -def gen_records(count: int, vec_bin: str, vec_dim: int): - num = 0 - while num < count: - key_and_rec = ( - num, - { "id": num, vec_bin: [float(num)] * vec_dim} - ) - yield key_and_rec - num += 1 - - -def wait_for_index(admin_client, namespace: str, index: str): - - verticies = 0 - unmerged_recs = 0 - - while verticies == 0 or unmerged_recs > 0: - status = admin_client.index_get_status( - namespace=namespace, - name=index, - ) - - verticies = status.index_healer_vertices_valid - unmerged_recs = status.unmerged_record_count - - # print(verticies) - # print(unmerged_recs) - time.sleep(0.5) \ No newline at end of file diff --git a/tests/standard/sync/test_admin_client_index_create.py b/tests/standard/test_admin_client_index_create.py similarity index 99% rename from tests/standard/sync/test_admin_client_index_create.py rename to tests/standard/test_admin_client_index_create.py index 3b90fdba..d869506f 100644 --- a/tests/standard/sync/test_admin_client_index_create.py +++ b/tests/standard/test_admin_client_index_create.py @@ -2,8 +2,7 @@ import grpc from aerospike_vector_search import types, AVSServerError -from ...utils import random_name, DEFAULT_NAMESPACE -from .sync_utils import drop_specified_index +from utils import random_name, drop_specified_index, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity diff --git a/tests/standard/sync/test_admin_client_index_drop.py b/tests/standard/test_admin_client_index_drop.py similarity index 100% rename from tests/standard/sync/test_admin_client_index_drop.py rename to tests/standard/test_admin_client_index_drop.py diff --git a/tests/standard/sync/test_admin_client_index_get.py 
b/tests/standard/test_admin_client_index_get.py similarity index 98% rename from tests/standard/sync/test_admin_client_index_get.py rename to tests/standard/test_admin_client_index_get.py index b8034a71..c6a77e00 100644 --- a/tests/standard/sync/test_admin_client_index_get.py +++ b/tests/standard/test_admin_client_index_get.py @@ -1,4 +1,4 @@ -from ...utils import DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD +from utils import DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD from aerospike_vector_search import AVSServerError import grpc diff --git a/tests/standard/sync/test_admin_client_index_get_status.py b/tests/standard/test_admin_client_index_get_status.py similarity index 94% rename from tests/standard/sync/test_admin_client_index_get_status.py rename to tests/standard/test_admin_client_index_get_status.py index 0f127829..93fa8c83 100644 --- a/tests/standard/sync/test_admin_client_index_get_status.py +++ b/tests/standard/test_admin_client_index_get_status.py @@ -1,9 +1,9 @@ import pytest import grpc -from ...utils import DEFAULT_NAMESPACE +from utils import DEFAULT_NAMESPACE -from .sync_utils import drop_specified_index +from utils import drop_specified_index from hypothesis import given, settings, Verbosity from aerospike_vector_search import AVSServerError diff --git a/tests/standard/sync/test_admin_client_index_list.py b/tests/standard/test_admin_client_index_list.py similarity index 97% rename from tests/standard/sync/test_admin_client_index_list.py rename to tests/standard/test_admin_client_index_list.py index 4a119573..55e3978b 100644 --- a/tests/standard/sync/test_admin_client_index_list.py +++ b/tests/standard/test_admin_client_index_list.py @@ -1,5 +1,5 @@ from aerospike_vector_search import AVSServerError -from .sync_utils import drop_specified_index +from utils import drop_specified_index import pytest import grpc diff --git a/tests/standard/sync/test_admin_client_index_update.py 
b/tests/standard/test_admin_client_index_update.py similarity index 100% rename from tests/standard/sync/test_admin_client_index_update.py rename to tests/standard/test_admin_client_index_update.py diff --git a/tests/standard/sync/test_extensive_vector_search.py b/tests/standard/test_extensive_vector_search.py similarity index 100% rename from tests/standard/sync/test_extensive_vector_search.py rename to tests/standard/test_extensive_vector_search.py diff --git a/tests/standard/sync/test_service_config.py b/tests/standard/test_service_config.py similarity index 98% rename from tests/standard/sync/test_service_config.py rename to tests/standard/test_service_config.py index b22f1088..1b530d71 100644 --- a/tests/standard/sync/test_service_config.py +++ b/tests/standard/test_service_config.py @@ -1,3 +1,8 @@ +# TODO refactor this file so it uses the client fixtures +# and so that it is less timing dependent +# ideally these would be unit tests maybe with mocked out grpc channels +# that ensure the correct config is recieved + # import pytest # import time diff --git a/tests/standard/sync/test_vector_client_delete.py b/tests/standard/test_vector_client_delete.py similarity index 100% rename from tests/standard/sync/test_vector_client_delete.py rename to tests/standard/test_vector_client_delete.py diff --git a/tests/standard/sync/test_vector_client_exists.py b/tests/standard/test_vector_client_exists.py similarity index 100% rename from tests/standard/sync/test_vector_client_exists.py rename to tests/standard/test_vector_client_exists.py diff --git a/tests/standard/sync/test_vector_client_get.py b/tests/standard/test_vector_client_get.py similarity index 100% rename from tests/standard/sync/test_vector_client_get.py rename to tests/standard/test_vector_client_get.py diff --git a/tests/standard/sync/test_vector_client_index_get_percent_unmerged.py b/tests/standard/test_vector_client_index_get_percent_unmerged.py similarity index 100% rename from 
tests/standard/sync/test_vector_client_index_get_percent_unmerged.py rename to tests/standard/test_vector_client_index_get_percent_unmerged.py diff --git a/tests/standard/sync/test_vector_client_insert.py b/tests/standard/test_vector_client_insert.py similarity index 100% rename from tests/standard/sync/test_vector_client_insert.py rename to tests/standard/test_vector_client_insert.py diff --git a/tests/standard/sync/test_vector_client_is_indexed.py b/tests/standard/test_vector_client_is_indexed.py similarity index 97% rename from tests/standard/sync/test_vector_client_is_indexed.py rename to tests/standard/test_vector_client_is_indexed.py index 0c4baa97..b2434bda 100644 --- a/tests/standard/sync/test_vector_client_is_indexed.py +++ b/tests/standard/test_vector_client_is_indexed.py @@ -2,7 +2,7 @@ from aerospike_vector_search import AVSServerError from utils import random_name, DEFAULT_NAMESPACE -from .sync_utils import wait_for_index +from utils import wait_for_index import grpc diff --git a/tests/standard/sync/test_vector_client_search_by_key.py b/tests/standard/test_vector_client_search_by_key.py similarity index 98% rename from tests/standard/sync/test_vector_client_search_by_key.py rename to tests/standard/test_vector_client_search_by_key.py index 044a0ba0..79528e24 100644 --- a/tests/standard/sync/test_vector_client_search_by_key.py +++ b/tests/standard/test_vector_client_search_by_key.py @@ -1,6 +1,7 @@ -import pytest from aerospike_vector_search import types +from utils import wait_for_index +import pytest class vector_search_by_key_test_case: def __init__( @@ -324,9 +325,10 @@ def test_vector_search_by_key( set_name=test_case.key_set, ) - session_vector_client.wait_for_index_completion( + wait_for_index( + admin_client=session_admin_client, namespace=test_case.search_namespace, - name=test_case.index_name, + index=test_case.index_name, ) results = session_vector_client.vector_search_by_key( @@ -386,9 +388,10 @@ def 
test_vector_search_by_key_different_namespaces( }, ) - session_vector_client.wait_for_index_completion( + wait_for_index( + admin_client=session_admin_client, namespace="index_storage", - name="diff_ns_idx", + index="diff_ns_idx", ) results = session_vector_client.vector_search_by_key( diff --git a/tests/standard/sync/test_vector_client_update.py b/tests/standard/test_vector_client_update.py similarity index 99% rename from tests/standard/sync/test_vector_client_update.py rename to tests/standard/test_vector_client_update.py index c11fe643..25b4ad39 100644 --- a/tests/standard/sync/test_vector_client_update.py +++ b/tests/standard/test_vector_client_update.py @@ -2,7 +2,7 @@ import grpc from aerospike_vector_search import AVSServerError -from ...utils import random_key +from utils import random_key from hypothesis import given, settings, Verbosity diff --git a/tests/standard/sync/test_vector_client_upsert.py b/tests/standard/test_vector_client_upsert.py similarity index 99% rename from tests/standard/sync/test_vector_client_upsert.py rename to tests/standard/test_vector_client_upsert.py index 82d1f3ba..359a4455 100644 --- a/tests/standard/sync/test_vector_client_upsert.py +++ b/tests/standard/test_vector_client_upsert.py @@ -1,5 +1,5 @@ import pytest -from ...utils import random_key +from utils import random_key from hypothesis import given, settings, Verbosity diff --git a/tests/standard/sync/test_vector_search.py b/tests/standard/test_vector_search.py similarity index 96% rename from tests/standard/sync/test_vector_search.py rename to tests/standard/test_vector_search.py index 415e038c..4a1a9291 100644 --- a/tests/standard/sync/test_vector_search.py +++ b/tests/standard/test_vector_search.py @@ -1,7 +1,7 @@ -import numpy as np -import pytest from aerospike_vector_search import types +from utils import wait_for_index +import pytest class vector_search_test_case: def __init__( @@ -118,9 +118,10 @@ def test_vector_search( set_name=test_case.set_name, ) - 
session_vector_client.wait_for_index_completion( + wait_for_index( + admin_client=session_admin_client, namespace=test_case.namespace, - name=test_case.index_name, + index=test_case.index_name, ) results = session_vector_client.vector_search( diff --git a/tests/utils.py b/tests/utils.py index ddcb1306..30bd7fe9 100755 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,9 @@ import random -import hypothesis.strategies as st -from hypothesis import given +import time import string +import hypothesis.strategies as st +from hypothesis import given import pytest @@ -48,6 +49,40 @@ def random_key(): return key +def drop_specified_index(admin_client, namespace, name): + admin_client.index_drop(namespace=namespace, name=name) + + +def gen_records(count: int, vec_bin: str, vec_dim: int): + num = 0 + while num < count: + key_and_rec = ( + num, + { "id": num, vec_bin: [float(num)] * vec_dim} + ) + yield key_and_rec + num += 1 + + +def wait_for_index(admin_client, namespace: str, index: str): + + verticies = 0 + unmerged_recs = 0 + + while verticies == 0 or unmerged_recs > 0: + status = admin_client.index_get_status( + namespace=namespace, + name=index, + ) + + verticies = status.index_healer_vertices_valid + unmerged_recs = status.unmerged_record_count + + # print(verticies) + # print(unmerged_recs) + time.sleep(0.5) + + """ def key_strategy(): return st.text(alphabet=allowed_chars, min_size=1, max_size=100_000).filter( From 95fe8bd312f1c3858237bfd1e2f499dd7581278e Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 27 Dec 2024 15:13:53 -0800 Subject: [PATCH 10/21] add a --sync test flag --- tests/standard/conftest.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/standard/conftest.py b/tests/standard/conftest.py index 1aa9ff31..ac58f679 100644 --- a/tests/standard/conftest.py +++ b/tests/standard/conftest.py @@ -396,6 +396,11 @@ def pytest_addoption(parser): action="store_true", help="Run tests using the async client", ) + parser.addoption( + "--sync", + 
action="store_true", + help="Run tests using the sync client", + ) @pytest.fixture(scope="module", autouse=True) @@ -403,6 +408,11 @@ def async_client(request): return request.config.getoption("--async") +@pytest.fixture(scope="module", autouse=True) +def sync_client(request): + return request.config.getoption("--sync") + + @pytest.fixture(scope="module", autouse=True) def username(request): return request.config.getoption("--username") From 6d2b466fdadaa9595a0b09042c3cad51b0d28bf3 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 27 Dec 2024 17:34:52 -0800 Subject: [PATCH 11/21] debug event loop --- tests/standard/conftest.py | 117 ++++++++++++++++++++++++++----------- 1 file changed, 83 insertions(+), 34 deletions(-) diff --git a/tests/standard/conftest.py b/tests/standard/conftest.py index ac58f679..342c4bd6 100644 --- a/tests/standard/conftest.py +++ b/tests/standard/conftest.py @@ -1,6 +1,7 @@ import asyncio import random import string +import threading from aerospike_vector_search import Client from aerospike_vector_search.aio import Client as AsyncClient @@ -79,33 +80,53 @@ def drop_all_indexes( client.index_drop(namespace="test", name=item["id"]["name"]) -@pytest.fixture(scope="session") +# @pytest.fixture(scope="session", autouse=True) +# def event_loop(): +# """ +# Create an event loop for the test session. +# The async client requires a running event loop. +# So we create and run one here. +# """ +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# def run_loop(loop): +# asyncio.set_event_loop(loop) +# loop.run_forever() + +# thr = threading.Thread(target=run_loop, args=(loop,), daemon=True) +# thr.start() + +# yield loop + +# loop.call_soon_threadsafe(loop.stop) +# thr.join() +# loop.close() + + +@pytest.fixture(scope="session", autouse=True) def event_loop(): """ - Create an event loop for the test session. + Create an event loop that runs in a separate thread. + The async client requires a running event loop at initialization. 
""" loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - #HACK the async client schedules tasks in its init function so we need - # to run the event loop to allow the tasks to be scheduled - yield loop - loop.close() + # Define the target function to run the loop + def run_loop(): + asyncio.set_event_loop(loop) + loop.run_forever() -class AsyncClientWrapper(): - def __init__(self, client): - self.client = client - - def __getattr__(self, name): - attr = getattr(self.client, name) - if asyncio.iscoroutinefunction(attr): - # Wrap async methods to run in the current event loop - def sync_method(*args, **kwargs): - loop = asyncio.get_event_loop() - return loop.run_until_complete(attr(*args, **kwargs)) + # Start the event loop in a background thread + loop_thread = threading.Thread(target=run_loop, daemon=True) + loop_thread.start() - return sync_method - return attr + yield loop + + # Stop the event loop and wait for the thread to finish + loop.call_soon_threadsafe(loop.stop) + loop_thread.join() + loop.close() async def new_wrapped_async_client( @@ -117,9 +138,10 @@ async def new_wrapped_async_client( certificate_chain, private_key, is_loadbalancer, - ssl_target_name_override + ssl_target_name_override, + loop ): - client = AsyncClient( + return AsyncClient( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -129,7 +151,6 @@ async def new_wrapped_async_client( private_key=private_key, ssl_target_name_override=ssl_target_name_override ) - return AsyncClientWrapper(client) async def new_wrapped_async_admin_client( @@ -141,9 +162,10 @@ async def new_wrapped_async_admin_client( certificate_chain, private_key, is_loadbalancer, - ssl_target_name_override + ssl_target_name_override, + loop ): - client = AsyncAdminClient( + return AsyncAdminClient( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -153,7 +175,30 @@ async def new_wrapped_async_admin_client( 
private_key=private_key, ssl_target_name_override=ssl_target_name_override ) - return AsyncClientWrapper(client) + + +class AsyncClientWrapper(): + def __init__(self, client, loop): + self.client = client + self.loop = loop + + def __getattr__(self, name): + attr = getattr(self.client, name) + if asyncio.iscoroutinefunction(attr): + # Wrap async methods to run in the current event loop + def sync_method(*args, **kwargs): + return self._run_async_task(attr(*args, **kwargs)) + + return sync_method + return attr + + def _run_async_task(self, task): + # Submit the coroutine to the loop and get its result + if self.loop.is_running(): + future = asyncio.run_coroutine_threadsafe(task, self.loop) + return future.result() + else: + raise RuntimeError("Event loop is not running") @pytest.fixture(scope="module") @@ -183,8 +228,7 @@ def session_admin_client( private_key = f.read() if async_client: - loop = asyncio.get_event_loop() - client = loop.run_until_complete(new_wrapped_async_admin_client( + task = new_wrapped_async_admin_client( host=host, port=port, username=username, @@ -193,8 +237,11 @@ def session_admin_client( certificate_chain=certificate_chain, private_key=private_key, is_loadbalancer=is_loadbalancer, - ssl_target_name_override=ssl_target_name_override - )) + ssl_target_name_override=ssl_target_name_override, + loop=event_loop + ) + client = asyncio.run_coroutine_threadsafe(task, event_loop).result() + client = AsyncClientWrapper(client, event_loop) else: client = AdminClient( seeds=types.HostPort(host=host, port=port), @@ -238,8 +285,7 @@ def session_vector_client( private_key = f.read() if async_client: - loop = asyncio.get_event_loop() - client = loop.run_until_complete(new_wrapped_async_client( + task = new_wrapped_async_client( host=host, port=port, username=username, @@ -248,8 +294,11 @@ def session_vector_client( certificate_chain=certificate_chain, private_key=private_key, is_loadbalancer=is_loadbalancer, - 
ssl_target_name_override=ssl_target_name_override - )) + ssl_target_name_override=ssl_target_name_override, + loop=event_loop + ) + client = asyncio.run_coroutine_threadsafe(task, event_loop).result() + client = AsyncClientWrapper(client, event_loop) else: client = Client( seeds=types.HostPort(host=host, port=port), From 00e0a6d7c1fd187e5bd7a779a88fd48f0c006a04 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 27 Dec 2024 19:36:06 -0800 Subject: [PATCH 12/21] remove faulty test case --- .../test_vector_client_search_by_key.py | 41 ++++++------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/tests/standard/test_vector_client_search_by_key.py b/tests/standard/test_vector_client_search_by_key.py index 79528e24..5ea3df98 100644 --- a/tests/standard/test_vector_client_search_by_key.py +++ b/tests/standard/test_vector_client_search_by_key.py @@ -261,34 +261,6 @@ def __init__( ), ], ), - # test search key record and search records are in different namespaces - vector_search_by_key_test_case( - index_name="basic_search", - index_dimensions=3, - vector_field="vector", - limit=2, - key="rec1", - key_namespace="test", - search_namespace="index_storage", - include_fields=None, - exclude_fields=None, - key_set=None, - record_data={ - "rec1": { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - "rec2": { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - "rec3": { - "bin": 3, - "vector": [3.0, 3.0, 3.0], - }, - }, - expected_results=[], - ), ], ) def test_vector_search_by_key( @@ -368,6 +340,19 @@ def test_vector_search_by_key_different_namespaces( name="diff_ns_idx", vector_field="vec", dimensions=3, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * 
* * * ?" + ) + ) ) session_vector_client.upsert( From 75580de70be2c3f72aac547fb8fb72b1a3cbba02 Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 30 Dec 2024 11:28:54 -0800 Subject: [PATCH 13/21] name sync and async test coverage files separately --- .github/workflows/integration_test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index a8ea56f0..07019fef 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -17,7 +17,7 @@ jobs: strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12"] - async: ["--sync", "--async"] + async: ["sync", "async"] steps: @@ -71,7 +71,7 @@ jobs: sleep 5 docker ps - python -m pytest standard -s --host 0.0.0.0 --port 5000 --cov=aerospike_vector_search ${{ matrix.async }} + python -m pytest standard -s --host 0.0.0.0 --port 5000 --cov=aerospike_vector_search --${{ matrix.async }} mv .coverage coverage_data working-directory: tests @@ -80,7 +80,7 @@ jobs: if: ${{ matrix.python-version == '3.12' }} uses: actions/upload-artifact@v4 with: - name: coverage_normal + name: coverage_normal_${{ matrix.async }} path: tests/coverage_data test-tls: From 943118fd9ddafe19a28878d324fcc81ca47860fc Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 30 Dec 2024 11:48:15 -0800 Subject: [PATCH 14/21] ci: print the 5 slowest integration tests --- .github/workflows/integration_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 07019fef..a08cc2dd 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -71,7 +71,7 @@ jobs: sleep 5 docker ps - python -m pytest standard -s --host 0.0.0.0 --port 5000 --cov=aerospike_vector_search --${{ matrix.async }} + python -m pytest standard -s --host 0.0.0.0 --port 5000 --cov=aerospike_vector_search --${{ matrix.async }} 
--durations=5 mv .coverage coverage_data working-directory: tests From f05bd41c031a31ae553b9b310e8d1e6f4ad14af3 Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 30 Dec 2024 13:07:41 -0800 Subject: [PATCH 15/21] run healer more frequently --- tests/standard/test_vector_search.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/standard/test_vector_search.py b/tests/standard/test_vector_search.py index 4a1a9291..a18414b4 100644 --- a/tests/standard/test_vector_search.py +++ b/tests/standard/test_vector_search.py @@ -108,6 +108,19 @@ def test_vector_search( name=test_case.index_name, vector_field=test_case.vector_field, dimensions=test_case.index_dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" 
+ ) + ) ) for key, rec in test_case.record_data.items(): From 2496782fb54b08f644a1551b82f1d4d54b9b9a03 Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 30 Dec 2024 15:32:35 -0800 Subject: [PATCH 16/21] ci: delete service config tests --- tests/standard/test_service_config.py | 500 -------------------------- 1 file changed, 500 deletions(-) delete mode 100644 tests/standard/test_service_config.py diff --git a/tests/standard/test_service_config.py b/tests/standard/test_service_config.py deleted file mode 100644 index 1b530d71..00000000 --- a/tests/standard/test_service_config.py +++ /dev/null @@ -1,500 +0,0 @@ -# TODO refactor this file so it uses the client fixtures -# and so that it is less timing dependent -# ideally these would be unit tests maybe with mocked out grpc channels -# that ensure the correct config is recieved - -# import pytest -# import time - -# import os -# import json - -# from aerospike_vector_search import AVSServerError, types -# from aerospike_vector_search import AdminClient - - -# class service_config_parse_test_case: -# def __init__(self, *, service_config_path): -# self.service_config_path = service_config_path - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_parse_test_case( -# service_config_path="service_configs/master.json" -# ), -# ], -# ) -# def test_admin_client_service_config_parse( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# 
private_key=private_key, -# service_config_path=test_case.service_config_path, -# ssl_target_name_override=ssl_target_name_override, -# ) as client: -# pass - - -# class service_config_test_case: -# def __init__( -# self, *, service_config_path, namespace, name, vector_field, dimensions -# ): - -# script_dir = os.path.dirname(os.path.abspath(__file__)) - -# self.service_config_path = os.path.abspath( -# os.path.join(script_dir, "..", "..", service_config_path) -# ) - -# with open(self.service_config_path, "rb") as f: -# self.service_config = json.load(f) - -# self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ -# "maxAttempts" -# ] -# self.initial_backoff = int( -# self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] -# ) -# self.max_backoff = int( -# self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] -# ) -# self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ -# "backoffMultiplier" -# ] -# self.retryable_status_codes = self.service_config["methodConfig"][0][ -# "retryPolicy" -# ]["retryableStatusCodes"] -# self.namespace = namespace -# self.name = name -# self.vector_field = vector_field -# self.dimensions = dimensions - - -# def calculate_expected_time( -# max_attempts, -# initial_backoff, -# backoff_multiplier, -# max_backoff, -# retryable_status_codes, -# ): - -# current_backkoff = initial_backoff - -# expected_time = 0 -# for attempt in range(max_attempts - 1): -# expected_time += current_backkoff -# current_backkoff *= backoff_multiplier -# current_backkoff = min(current_backkoff, max_backoff) - -# return expected_time - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/retries.json", -# namespace="test", -# name="service_config_index_1", -# vector_field="example_1", -# dimensions=1024, -# ) -# ], -# ) -# def test_admin_client_service_config_retries( -# host, -# port, -# username, -# password, 
-# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# service_config_path=test_case.service_config_path, -# ssl_target_name_override=ssl_target_name_override, - -# ) as client: - -# try: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time - -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/initial_backoff.json", -# namespace="test", -# name="service_config_index_2", -# vector_field="example_1", -# dimensions=1024, -# ) -# ], -# ) -# def test_admin_client_service_config_initial_backoff( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# 
with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: - -# try: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time - -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/max_backoff.json", -# namespace="test", -# name="service_config_index_3", -# vector_field="example_1", -# dimensions=1024, -# ), -# service_config_test_case( -# service_config_path="service_configs/max_backoff_lower_than_initial.json", -# namespace="test", -# name="service_config_index_4", -# vector_field="example_1", -# dimensions=1024, -# ), -# ], -# ) -# def test_admin_client_service_config_max_backoff( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# 
ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: - -# try: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/backoff_multiplier.json", -# namespace="test", -# name="service_config_index_5", -# vector_field="example_1", -# dimensions=1024, -# ) -# ], -# ) -# def test_admin_client_service_config_backoff_multiplier( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate 
= f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: -# with open(private_key, "rb") as f: -# private_key = f.read() - -# with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: - -# try: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) -# except: -# pass -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# client.index_create( -# namespace=test_case.namespace, -# name=test_case.name, -# vector_field=test_case.vector_field, -# dimensions=test_case.dimensions, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time -# assert abs(elapsed_time - expected_time) < 1.5 - - -# @pytest.mark.parametrize( -# "test_case", -# [ -# service_config_test_case( -# service_config_path="service_configs/retryable_status_codes.json", -# namespace="test", -# name="service_config_index_6", -# vector_field=None, -# dimensions=None, -# ) -# ], -# ) -# def test_admin_client_service_config_retryable_status_codes( -# host, -# port, -# username, -# password, -# root_certificate, -# certificate_chain, -# private_key, -# ssl_target_name_override, -# test_case, -# ): - -# if root_certificate: -# with open(root_certificate, "rb") as f: -# root_certificate = f.read() - -# if certificate_chain: -# with open(certificate_chain, "rb") as f: -# certificate_chain = f.read() -# if private_key: 
-# with open(private_key, "rb") as f: -# private_key = f.read() - -# with AdminClient( -# seeds=types.HostPort(host=host, port=port), -# username=username, -# password=password, -# root_certificate=root_certificate, -# certificate_chain=certificate_chain, -# private_key=private_key, -# ssl_target_name_override=ssl_target_name_override, -# service_config_path=test_case.service_config_path, -# ) as client: - -# expected_time = calculate_expected_time( -# test_case.max_attempts, -# test_case.initial_backoff, -# test_case.backoff_multiplier, -# test_case.max_backoff, -# test_case.retryable_status_codes, -# ) -# start_time = time.time() - -# with pytest.raises(AVSServerError) as e_info: -# client.index_get_status( -# namespace=test_case.namespace, -# name=test_case.name, -# ) - -# end_time = time.time() -# elapsed_time = end_time - start_time -# assert abs(elapsed_time - expected_time) < 1.5 From 9f96ee000f19eee88949114a12aacac55756cdde Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 30 Dec 2024 15:43:20 -0800 Subject: [PATCH 17/21] remove commented code --- tests/standard/conftest.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/standard/conftest.py b/tests/standard/conftest.py index 342c4bd6..f5590e64 100644 --- a/tests/standard/conftest.py +++ b/tests/standard/conftest.py @@ -80,30 +80,6 @@ def drop_all_indexes( client.index_drop(namespace="test", name=item["id"]["name"]) -# @pytest.fixture(scope="session", autouse=True) -# def event_loop(): -# """ -# Create an event loop for the test session. -# The async client requires a running event loop. -# So we create and run one here. 
-# """ -# loop = asyncio.new_event_loop() -# asyncio.set_event_loop(loop) - -# def run_loop(loop): -# asyncio.set_event_loop(loop) -# loop.run_forever() - -# thr = threading.Thread(target=run_loop, args=(loop,), daemon=True) -# thr.start() - -# yield loop - -# loop.call_soon_threadsafe(loop.stop) -# thr.join() -# loop.close() - - @pytest.fixture(scope="session", autouse=True) def event_loop(): """ From 8717a88e068e0b94ba150e687d7c30a7943763a4 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 2 Jan 2025 14:55:49 -0800 Subject: [PATCH 18/21] refactor:! merged sync admin and non-admin clients --- docs/admin.rst | 10 - src/aerospike_vector_search/__init__.py | 1 - src/aerospike_vector_search/admin.py | 756 ------------------ src/aerospike_vector_search/client.py | 655 ++++++++++++++- .../shared/client_helpers.py | 6 - tests/rbac/sync/conftest.py | 7 +- tests/rbac/sync/test_admin_client_add_user.py | 12 +- .../rbac/sync/test_admin_client_drop_user.py | 8 +- tests/rbac/sync/test_admin_client_get_user.py | 6 +- .../sync/test_admin_client_grant_roles.py | 8 +- .../rbac/sync/test_admin_client_list_roles.py | 6 +- .../rbac/sync/test_admin_client_list_users.py | 6 +- .../sync/test_admin_client_revoke_roles.py | 8 +- .../test_admin_client_update_credentials.py | 8 +- tests/standard/conftest.py | 91 +-- .../test_admin_client_index_create.py | 74 +- .../standard/test_admin_client_index_drop.py | 10 +- tests/standard/test_admin_client_index_get.py | 12 +- .../test_admin_client_index_get_status.py | 10 +- .../standard/test_admin_client_index_list.py | 8 +- .../test_admin_client_index_update.py | 6 +- .../standard/test_extensive_vector_search.py | 27 +- .../standard/test_vector_client_is_indexed.py | 3 +- .../test_vector_client_search_by_key.py | 14 +- tests/standard/test_vector_search.py | 7 +- 25 files changed, 768 insertions(+), 991 deletions(-) delete mode 100644 docs/admin.rst delete mode 100644 src/aerospike_vector_search/admin.py diff --git a/docs/admin.rst b/docs/admin.rst 
deleted file mode 100644 index eea13706..00000000 --- a/docs/admin.rst +++ /dev/null @@ -1,10 +0,0 @@ -AdminClient -===================== - -This class is the admin client, designed to conduct AVS administrative operation such as creating indexes, querying index information, and dropping indexes. - - -.. autoclass:: aerospike_vector_search.admin.Client - :members: - :undoc-members: - :show-inheritance: diff --git a/src/aerospike_vector_search/__init__.py b/src/aerospike_vector_search/__init__.py index cc942ecb..d08c2b3b 100644 --- a/src/aerospike_vector_search/__init__.py +++ b/src/aerospike_vector_search/__init__.py @@ -1,5 +1,4 @@ from .client import Client -from .admin import Client as AdminClient from .types import ( HostPort, Key, diff --git a/src/aerospike_vector_search/admin.py b/src/aerospike_vector_search/admin.py deleted file mode 100644 index a9638d09..00000000 --- a/src/aerospike_vector_search/admin.py +++ /dev/null @@ -1,756 +0,0 @@ -import logging -import sys -import time -from typing import Optional, Union - -import grpc - -from . import types -from .internal import channel_provider -from .shared.admin_helpers import BaseClient -from .shared.conversions import fromIndexStatusResponse -from .types import IndexDefinition, Role - -logger = logging.getLogger(__name__) - - -class Client(BaseClient): - """ - Aerospike Vector Search Admin Client - - This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. - - :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster. - :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] - - :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. 
Clients can access AVS on a node using the advertised listener address and port. Defaults to None. - :type listener_name: Optional[str] - - :param is_loadbalancer: If true, the first seed address will be treated as a load balancer node. Defaults to False. - :type is_loadbalancer: Optional[bool] - - :param service_config_path: Path to the service configuration file. Defaults to None. - :type service_config_path: Optional[str] - - :param username: Username for Role-Based Access. Defaults to None. - :type username: Optional[str] - - :param password: Password for Role-Based Access. Defaults to None. - :type password: Optional[str] - - :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. - :type root_certificate: Optional[list[bytes], bytes] - - :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. - :type certificate_chain: Optional[bytes] - - :param private_key: The PEM-encoded private key as a byte string. Defaults to None. - :type private_key: Optional[bytes] - - :raises AVSClientError: Raised when no seed host is provided. 
- - """ - - def __init__( - self, - *, - seeds: Union[types.HostPort, tuple[types.HostPort, ...]], - listener_name: Optional[str] = None, - is_loadbalancer: Optional[bool] = False, - username: Optional[str] = None, - password: Optional[str] = None, - root_certificate: Optional[Union[list[str], str]] = None, - certificate_chain: Optional[str] = None, - private_key: Optional[str] = None, - service_config_path: Optional[str] = None, - ssl_target_name_override: Optional[str] = None, - ) -> None: - seeds = self._prepare_seeds(seeds) - - self._channel_provider = channel_provider.ChannelProvider( - seeds, - listener_name, - is_loadbalancer, - username, - password, - root_certificate, - certificate_chain, - private_key, - service_config_path, - ssl_target_name_override, - ) - - def index_create( - self, - *, - namespace: str, - name: str, - vector_field: str, - dimensions: int, - vector_distance_metric: types.VectorDistanceMetric = ( - types.VectorDistanceMetric.SQUARED_EUCLIDEAN - ), - sets: Optional[str] = None, - index_params: Optional[types.HnswParams] = None, - index_labels: Optional[dict[str, str]] = None, - index_storage: Optional[types.IndexStorage] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Create an index. - - :param namespace: The namespace for the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param vector_field: The name of the field containing vector data. - :type vector_field: str - - :param dimensions: The number of dimensions in the vector data. - :type dimensions: int - - :param vector_distance_metric: - The distance metric used to compare when performing a vector search. - Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type vector_distance_metric: types.VectorDistanceMetric - - :param sets: The set used for the index. Defaults to None. - :type sets: Optional[str] - - :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning - vector search. 
Defaults to None. If index_params is None, then the default values - specified for :class:`types.HnswParams` will be used. - :type index_params: Optional[types.HnswParams] - - :param index_labels: Metadata associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param index_storage: Namespace and set where index overhead (non-vector data) is stored. - :type index_storage: Optional[types.IndexStorage] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method creates an index with the specified parameters and waits for the index creation to complete. - It waits for up to 100,000 seconds for the index creation to complete. 
- """ - - (index_stub, index_create_request, kwargs) = self._prepare_index_create( - namespace, - name, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_labels, - index_storage, - timeout, - logger, - ) - - try: - index_stub.Create( - index_create_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to create index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - self._wait_for_index_creation( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def index_update( - self, - *, - namespace: str, - name: str, - index_labels: Optional[dict[str, str]] = None, - hnsw_update_params: Optional[types.HnswIndexUpdate] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Update an existing index. - - :param namespace: The namespace for the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param index_labels: Optional labels associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param hnsw_update_params: Parameters for updating HNSW index settings. - :type hnsw_update_params: Optional[types.HnswIndexUpdate] - - :param timeout: Time in seconds (default 100_000) this operation will wait before raising an error. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. 
- """ - (index_stub, index_update_request, kwargs) = self._prepare_index_update( - namespace = namespace, - name = name, - index_labels = index_labels, - hnsw_update_params = hnsw_update_params, - timeout = timeout, - logger = logger, - ) - - try: - index_stub.Update( - index_update_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - def index_drop( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> None: - """ - Drop an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method drops an index with the specified parameters and waits for the index deletion to complete. - It waits for up to 100,000 seconds for the index deletion to complete. 
- """ - - (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( - namespace, name, timeout, logger - ) - - try: - index_stub.Drop( - index_drop_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - self._wait_for_index_deletion( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for deletion with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def index_list( - self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True - ) -> list[IndexDefinition]: - """ - List all indices. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: list[dict]: A list of indices. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - """ - - (index_stub, index_list_request, kwargs) = self._prepare_index_list( - timeout, logger, apply_defaults - ) - - try: - response = index_stub.List( - index_list_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list indexes with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_list(response) - - def index_get( - self, - *, - namespace: str, - name: str, - timeout: Optional[int] = None, - apply_defaults: Optional[bool] = True, - ) -> IndexDefinition: - """ - Retrieve the information related with an index. 
- - :param namespace: The namespace of the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: dict[str, Union[int, str]: Information about an index. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - - (index_stub, index_get_request, kwargs) = self._prepare_index_get( - namespace, name, timeout, logger, apply_defaults - ) - - try: - response = index_stub.Get( - index_get_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_get(response) - - def index_get_status( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> types.IndexStatusResponse: - """ - Retrieve the number of records queued to be merged into an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Returns: IndexStatusResponse: AVS response containing index status information. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - Note: - This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, - the records may not immediately begin to merge into the index. - - Warning: This API is subject to change. - """ - (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( - namespace, name, timeout, logger - ) - - try: - response = index_stub.GetStatus( - index_get_status_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - return fromIndexStatusResponse(response) - except grpc.RpcError as e: - logger.error("Failed to get index status with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - - - def add_user( - self, - *, - username: str, - password: str, - roles: list[str], - timeout: Optional[int] = None, - ) -> None: - """ - Add role-based access AVS User to the AVS Server. - - :param username: Username for the new user. - :type username: str - - :param password: Password for the new user. - :type password: str - - :param roles: Roles for the new user. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( - username, password, roles, timeout, logger - ) - - try: - user_admin_stub.AddUser( - add_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to add user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def update_credentials( - self, *, username: str, password: str, timeout: Optional[int] = None - ) -> None: - """ - Update AVS User credentials. - - :param username: Username of the user to update. - :type username: str - - :param password: New password for the user. - :type password: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, update_credentials_request, kwargs) = ( - self._prepare_update_credentials(username, password, timeout, logger) - ) - - try: - user_admin_stub.UpdateCredentials( - update_credentials_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update credentials with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: - """ - Drops AVS User from the AVS Server. - - :param username: Username of the user to drop. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( - username, timeout, logger - ) - - try: - user_admin_stub.DropUser( - drop_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def get_user(self, *, username: str, timeout: Optional[int] = None) -> types.User: - """ - Retrieves AVS User information from the AVS Server. - - :param username: Username of the user to be retrieved. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - return: types.User: AVS User - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( - username, timeout, logger - ) - - try: - response = user_admin_stub.GetUser( - get_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - return self._respond_get_user(response) - - def list_users(self, timeout: Optional[int] = None) -> list[types.User]: - """ - List all users existing on the AVS Server. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type timeout: int - - return: list[types.User]: list of AVS Users - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( - timeout, logger - ) - - try: - response = user_admin_stub.ListUsers( - list_users_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_users(response) - - def grant_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> None: - """ - Grant roles to existing AVS Users. - - :param username: Username of the user which will receive the roles. - :type username: str - - :param roles: Roles the specified user will receive. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( - username, roles, timeout, logger - ) - - try: - user_admin_stub.GrantRoles( - grant_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to grant roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def revoke_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> None: - """ - Revoke roles from existing AVS Users. - - :param username: Username of the user undergoing role removal. - :type username: str - - :param roles: Roles to be revoked. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( - username, roles, timeout, logger - ) - - try: - user_admin_stub.RevokeRoles( - revoke_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to revoke roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def list_roles(self, timeout: Optional[int] = None) -> list[Role]: - """ - List roles available on the AVS server. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - returns: list[str]: Roles available in the AVS Server. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. 
- This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - """ - (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( - timeout, logger - ) - - try: - response = user_admin_stub.ListRoles( - list_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_roles(response) - - def _wait_for_index_creation( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be created. - """ - - (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - while True: - self._check_timeout(start_time, timeout) - try: - index_stub.GetStatus( - index_creation_request, - credentials=self._channel_provider.get_token(), - ) - logger.debug("Index created successfully") - # Index has been created - return - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - - # Wait for some more time. - time.sleep(wait_interval) - else: - logger.error("Failed waiting for index creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def _wait_for_index_deletion( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be deleted. - """ - - # Wait interval between polling - (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - - while True: - self._check_timeout(start_time, timeout) - - try: - index_stub.GetStatus( - index_deletion_request, - credentials=self._channel_provider.get_token(), - ) - # Wait for some more time. 
- time.sleep(wait_interval) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - logger.debug("Index deleted successfully") - # Index has been created - return - else: - logger.error("Failed waiting for index deletion with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def close(self): - """ - Close the Aerospike Vector Search Admin Client. - - This method closes gRPC channels connected to Aerospike Vector Search. - - Note: - This method should be called when the VectorDbAdminClient is no longer needed to release resources. - """ - self._channel_provider.close() - - def __enter__(self): - """ - Enter a context manager for the admin client. - - Returns: - VectorDbAdminClient: Aerospike Vector Search Admin Client instance. - """ - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Exit a context manager for the admin client. - """ - self.close() diff --git a/src/aerospike_vector_search/client.py b/src/aerospike_vector_search/client.py index b78b26a1..c8292563 100644 --- a/src/aerospike_vector_search/client.py +++ b/src/aerospike_vector_search/client.py @@ -9,12 +9,14 @@ from . 
import types from .internal import channel_provider -from .shared.client_helpers import BaseClient +from .shared.admin_helpers import BaseClient as BaseAdminClientMixin +from .shared.client_helpers import BaseClient as BaseClientMixin +from .shared.conversions import fromIndexStatusResponse logger = logging.getLogger(__name__) -class Client(BaseClient): +class Client(BaseClientMixin, BaseAdminClientMixin): """ Aerospike Vector Search Client @@ -792,6 +794,655 @@ def wait_for_index_completion( validation_count = 0 time.sleep(wait_interval_float) + def index_create( + self, + *, + namespace: str, + name: str, + vector_field: str, + dimensions: int, + vector_distance_metric: types.VectorDistanceMetric = ( + types.VectorDistanceMetric.SQUARED_EUCLIDEAN + ), + sets: Optional[str] = None, + index_params: Optional[types.HnswParams] = None, + index_labels: Optional[dict[str, str]] = None, + index_storage: Optional[types.IndexStorage] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Create an index. + + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param vector_field: The name of the field containing vector data. + :type vector_field: str + + :param dimensions: The number of dimensions in the vector data. + :type dimensions: int + + :param vector_distance_metric: + The distance metric used to compare when performing a vector search. + Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. + :type vector_distance_metric: types.VectorDistanceMetric + + :param sets: The set used for the index. Defaults to None. + :type sets: Optional[str] + + :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning + vector search. Defaults to None. If index_params is None, then the default values + specified for :class:`types.HnswParams` will be used. 
+ :type index_params: Optional[types.HnswParams] + + :param index_labels: Metadata associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param index_storage: Namespace and set where index overhead (non-vector data) is stored. + :type index_storage: Optional[types.IndexStorage] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method creates an index with the specified parameters and waits for the index creation to complete. + It waits for up to 100,000 seconds for the index creation to complete. + """ + + (index_stub, index_create_request, kwargs) = self._prepare_index_create( + namespace, + name, + vector_field, + dimensions, + vector_distance_metric, + sets, + index_params, + index_labels, + index_storage, + timeout, + logger, + ) + + try: + index_stub.Create( + index_create_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to create index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + self._wait_for_index_creation( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def index_update( + self, + *, + namespace: str, + name: str, + index_labels: Optional[dict[str, str]] = None, + hnsw_update_params: Optional[types.HnswIndexUpdate] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Update an existing index. + + :param namespace: The namespace for the index. 
+ :type namespace: str + + :param name: The name of the index. + :type name: str + + :param index_labels: Optional labels associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param hnsw_update_params: Parameters for updating HNSW index settings. + :type hnsw_update_params: Optional[types.HnswIndexUpdate] + + :param timeout: Time in seconds (default 100_000) this operation will wait before raising an error. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. + """ + (index_stub, index_update_request, kwargs) = self._prepare_index_update( + namespace = namespace, + name = name, + index_labels = index_labels, + hnsw_update_params = hnsw_update_params, + timeout = timeout, + logger = logger, + ) + + try: + index_stub.Update( + index_update_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + def index_drop( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> None: + """ + Drop an index. + + :param namespace: The namespace of the index. + :type name: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method drops an index with the specified parameters and waits for the index deletion to complete. + It waits for up to 100,000 seconds for the index deletion to complete. 
+ """ + + (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( + namespace, name, timeout, logger + ) + + try: + index_stub.Drop( + index_drop_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + self._wait_for_index_deletion( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for deletion with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def index_list( + self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True + ) -> list[types.IndexDefinition]: + """ + List all indices. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: list[dict]: A list of indices. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + + (index_stub, index_list_request, kwargs) = self._prepare_index_list( + timeout, logger, apply_defaults + ) + + try: + response = index_stub.List( + index_list_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list indexes with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_list(response) + + def index_get( + self, + *, + namespace: str, + name: str, + timeout: Optional[int] = None, + apply_defaults: Optional[bool] = True, + ) -> types.IndexDefinition: + """ + Retrieve the information related with an index. 
+ + :param namespace: The namespace of the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: dict[str, Union[int, str]: Information about an index. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + + (index_stub, index_get_request, kwargs) = self._prepare_index_get( + namespace, name, timeout, logger, apply_defaults + ) + + try: + response = index_stub.Get( + index_get_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_get(response) + + def index_get_status( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> types.IndexStatusResponse: + """ + Retrieve the number of records queued to be merged into an index. + + :param namespace: The namespace of the index. + :type name: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Returns: IndexStatusResponse: AVS response containing index status information. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + Note: + This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, + the records may not immediately begin to merge into the index. + + Warning: This API is subject to change. + """ + (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( + namespace, name, timeout, logger + ) + + try: + response = index_stub.GetStatus( + index_get_status_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + return fromIndexStatusResponse(response) + except grpc.RpcError as e: + logger.error("Failed to get index status with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + + + def add_user( + self, + *, + username: str, + password: str, + roles: list[str], + timeout: Optional[int] = None, + ) -> None: + """ + Add role-based access AVS User to the AVS Server. + + :param username: Username for the new user. + :type username: str + + :param password: Password for the new user. + :type password: str + + :param roles: Roles for the new user. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( + username, password, roles, timeout, logger + ) + + try: + user_admin_stub.AddUser( + add_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to add user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def update_credentials( + self, *, username: str, password: str, timeout: Optional[int] = None + ) -> None: + """ + Update AVS User credentials. + + :param username: Username of the user to update. + :type username: str + + :param password: New password for the user. + :type password: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, update_credentials_request, kwargs) = ( + self._prepare_update_credentials(username, password, timeout, logger) + ) + + try: + user_admin_stub.UpdateCredentials( + update_credentials_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update credentials with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: + """ + Drops AVS User from the AVS Server. + + :param username: Username of the user to drop. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
+ :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( + username, timeout, logger + ) + + try: + user_admin_stub.DropUser( + drop_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def get_user(self, *, username: str, timeout: Optional[int] = None) -> types.User: + """ + Retrieves AVS User information from the AVS Server. + + :param username: Username of the user to be retrieved. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + return: types.User: AVS User + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( + username, timeout, logger + ) + + try: + response = user_admin_stub.GetUser( + get_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + return self._respond_get_user(response) + + def list_users(self, timeout: Optional[int] = None) -> list[types.User]: + """ + List all users existing on the AVS Server. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
+ :type timeout: int + + return: list[types.User]: list of AVS Users + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( + timeout, logger + ) + + try: + response = user_admin_stub.ListUsers( + list_users_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_users(response) + + def grant_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) -> None: + """ + Grant roles to existing AVS Users. + + :param username: Username of the user which will receive the roles. + :type username: str + + :param roles: Roles the specified user will receive. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( + username, roles, timeout, logger + ) + + try: + user_admin_stub.GrantRoles( + grant_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to grant roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def revoke_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) -> None: + """ + Revoke roles from existing AVS Users. + + :param username: Username of the user undergoing role removal. + :type username: str + + :param roles: Roles to be revoked. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( + username, roles, timeout, logger + ) + + try: + user_admin_stub.RevokeRoles( + revoke_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to revoke roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def list_roles(self, timeout: Optional[int] = None) -> list[types.Role]: + """ + List roles available on the AVS server. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + returns: list[str]: Roles available in the AVS Server. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. 
+ This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( + timeout, logger + ) + + try: + response = user_admin_stub.ListRoles( + list_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_roles(response) + + def _wait_for_index_creation( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be created. + """ + + (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + while True: + self._check_timeout(start_time, timeout) + try: + index_stub.GetStatus( + index_creation_request, + credentials=self._channel_provider.get_token(), + ) + logger.debug("Index created successfully") + # Index has been created + return + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + + # Wait for some more time. + time.sleep(wait_interval) + else: + logger.error("Failed waiting for index creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def _wait_for_index_deletion( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be deleted. + """ + + # Wait interval between polling + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + + while True: + self._check_timeout(start_time, timeout) + + try: + index_stub.GetStatus( + index_deletion_request, + credentials=self._channel_provider.get_token(), + ) + # Wait for some more time. 
+ time.sleep(wait_interval) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + logger.debug("Index deleted successfully") + # Index has been created + return + else: + logger.error("Failed waiting for index deletion with error: %s", e) + raise types.AVSServerError(rpc_error=e) + def close(self): """ Close the Aerospike Vector Search Client. diff --git a/src/aerospike_vector_search/shared/client_helpers.py b/src/aerospike_vector_search/shared/client_helpers.py index 2ae5c12d..80a831e2 100644 --- a/src/aerospike_vector_search/shared/client_helpers.py +++ b/src/aerospike_vector_search/shared/client_helpers.py @@ -352,12 +352,6 @@ def _get_key( else: raise Exception("Invalid key type" + str(type(key))) return key - - def _prepare_wait_for_index_waiting(self, namespace: str, name: str, wait_interval: int) -> ( - Tuple)[index_pb2_grpc.IndexServiceStub, float, float, bool, int, index_pb2.IndexGetRequest]: - return helpers._prepare_wait_for_index_waiting( - self, namespace, name, wait_interval - ) def _prepare_index_get_percent_unmerged(self, namespace: str, name: str, timeout: Optional[int], logger: Logger) -> ( Tuple)[index_pb2_grpc.IndexServiceStub, index_pb2.IndexStatusRequest, dict[str, Any]]: diff --git a/tests/rbac/sync/conftest.py b/tests/rbac/sync/conftest.py index 5c986c56..5b8f98ed 100644 --- a/tests/rbac/sync/conftest.py +++ b/tests/rbac/sync/conftest.py @@ -1,7 +1,6 @@ import pytest from aerospike_vector_search import Client -from aerospike_vector_search.admin import Client as AdminClient from aerospike_vector_search import types @@ -29,7 +28,7 @@ def drop_all_indexes( with open(private_key, "rb") as f: private_key = f.read() - with AdminClient( + with Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -47,7 +46,7 @@ def drop_all_indexes( @pytest.fixture(scope="module") -def session_rbac_admin_client( +def session_rbac_client( username, password, root_certificate, @@ -70,7 +69,7 @@ 
def session_rbac_admin_client( with open(private_key, "rb") as f: private_key = f.read() - client = AdminClient( + client = Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, diff --git a/tests/rbac/sync/test_admin_client_add_user.py b/tests/rbac/sync/test_admin_client_add_user.py index aa5a27c6..381b80a6 100644 --- a/tests/rbac/sync/test_admin_client_add_user.py +++ b/tests/rbac/sync/test_admin_client_add_user.py @@ -25,12 +25,12 @@ def __init__( ), ], ) -def test_add_user(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_add_user(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username @@ -57,12 +57,12 @@ def test_add_user(session_rbac_admin_client, test_case): ), ], ) -def test_add_user_with_roles(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_add_user_with_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_drop_user.py b/tests/rbac/sync/test_admin_client_drop_user.py index a18bee71..e665836a 100644 --- a/tests/rbac/sync/test_admin_client_drop_user.py +++ b/tests/rbac/sync/test_admin_client_drop_user.py @@ -24,13 +24,13 @@ def __init__( ), ], ) -def test_drop_user(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_drop_user(session_rbac_client, test_case): + session_rbac_client.add_user( 
username=test_case.username, password=test_case.password, roles=None ) - session_rbac_admin_client.drop_user( + session_rbac_client.drop_user( username=test_case.username, ) with pytest.raises(AVSServerError) as e_info: - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert e_info.value.rpc_error.code() == grpc.StatusCode.NOT_FOUND diff --git a/tests/rbac/sync/test_admin_client_get_user.py b/tests/rbac/sync/test_admin_client_get_user.py index 69c76a92..61bdc405 100644 --- a/tests/rbac/sync/test_admin_client_get_user.py +++ b/tests/rbac/sync/test_admin_client_get_user.py @@ -22,12 +22,12 @@ def __init__( ), ], ) -def test_get_user(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_get_user(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_grant_roles.py b/tests/rbac/sync/test_admin_client_grant_roles.py index dc74bef0..8ce18884 100644 --- a/tests/rbac/sync/test_admin_client_grant_roles.py +++ b/tests/rbac/sync/test_admin_client_grant_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, granted_roles): ), ], ) -def test_grant_roles(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_grant_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - session_rbac_admin_client.grant_roles( + session_rbac_client.grant_roles( username=test_case.username, roles=test_case.granted_roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = 
session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_list_roles.py b/tests/rbac/sync/test_admin_client_list_roles.py index a3d1bf82..35030910 100644 --- a/tests/rbac/sync/test_admin_client_list_roles.py +++ b/tests/rbac/sync/test_admin_client_list_roles.py @@ -25,11 +25,11 @@ def __init__( ), ], ) -def test_list_roles(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_list_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_admin_client.list_roles() + result = session_rbac_client.list_roles() for role in result: assert role.id in test_case.roles diff --git a/tests/rbac/sync/test_admin_client_list_users.py b/tests/rbac/sync/test_admin_client_list_users.py index 0b713250..d029472d 100644 --- a/tests/rbac/sync/test_admin_client_list_users.py +++ b/tests/rbac/sync/test_admin_client_list_users.py @@ -17,12 +17,12 @@ def __init__(self, *, username, password): ), ], ) -def test_list_users(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_list_users(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - result = session_rbac_admin_client.list_users() + result = session_rbac_client.list_users() user_found = False for user in result: if user.username == test_case.username: diff --git a/tests/rbac/sync/test_admin_client_revoke_roles.py b/tests/rbac/sync/test_admin_client_revoke_roles.py index 0620fb24..04d7d704 100644 --- a/tests/rbac/sync/test_admin_client_revoke_roles.py +++ b/tests/rbac/sync/test_admin_client_revoke_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, revoked_roles): ), ], ) -def test_revoke_roles(session_rbac_admin_client, test_case): - 
session_rbac_admin_client.add_user( +def test_revoke_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - session_rbac_admin_client.revoke_roles( + session_rbac_client.revoke_roles( username=test_case.username, roles=test_case.roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_update_credentials.py b/tests/rbac/sync/test_admin_client_update_credentials.py index 3e2d7894..328c05b2 100644 --- a/tests/rbac/sync/test_admin_client_update_credentials.py +++ b/tests/rbac/sync/test_admin_client_update_credentials.py @@ -19,17 +19,17 @@ def __init__(self, *, username, old_password, new_password): ), ], ) -def test_update_credentials(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_update_credentials(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.old_password, roles=None ) - session_rbac_admin_client.update_credentials( + session_rbac_client.update_credentials( username=test_case.username, password=test_case.new_password, ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/standard/conftest.py b/tests/standard/conftest.py index f5590e64..020b51a6 100644 --- a/tests/standard/conftest.py +++ b/tests/standard/conftest.py @@ -5,8 +5,6 @@ from aerospike_vector_search import Client from aerospike_vector_search.aio import Client as AsyncClient -from aerospike_vector_search.admin import Client as AdminClient -from aerospike_vector_search.aio.admin import Client as AsyncAdminClient from aerospike_vector_search import types, 
AVSServerError from utils import gen_records, DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD @@ -63,7 +61,7 @@ def drop_all_indexes( with open(private_key, "rb") as f: private_key = f.read() - with AdminClient( + with Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -129,30 +127,6 @@ async def new_wrapped_async_client( ) -async def new_wrapped_async_admin_client( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - is_loadbalancer, - ssl_target_name_override, - loop -): - return AsyncAdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override - ) - - class AsyncClientWrapper(): def __init__(self, client, loop): self.client = client @@ -177,63 +151,6 @@ def _run_async_task(self, task): raise RuntimeError("Event loop is not running") -@pytest.fixture(scope="module") -def session_admin_client( - username, - password, - root_certificate, - host, - port, - certificate_chain, - private_key, - is_loadbalancer, - ssl_target_name_override, - async_client, - event_loop, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - if async_client: - task = new_wrapped_async_admin_client( - host=host, - port=port, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - is_loadbalancer=is_loadbalancer, - ssl_target_name_override=ssl_target_name_override, - loop=event_loop - ) - client = 
asyncio.run_coroutine_threadsafe(task, event_loop).result() - client = AsyncClientWrapper(client, event_loop) - else: - client = AdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override - ) - - yield client - client.close() - - @pytest.fixture(scope="module") def session_vector_client( username, @@ -298,12 +215,12 @@ def index_name(): @pytest.fixture(params=[DEFAULT_INDEX_ARGS]) -def index(session_admin_client, index_name, request): +def index(session_vector_client, index_name, request): args = request.param namespace = args.get("namespace", DEFAULT_NAMESPACE) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - session_admin_client.index_create( + session_vector_client.index_create( name = index_name, namespace = namespace, vector_field = vector_field, @@ -324,7 +241,7 @@ def index(session_admin_client, index_name, request): ) yield index_name try: - session_admin_client.index_drop(namespace=namespace, name=index_name) + session_vector_client.index_drop(namespace=namespace, name=index_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass diff --git a/tests/standard/test_admin_client_index_create.py b/tests/standard/test_admin_client_index_create.py index d869506f..bb64ae47 100644 --- a/tests/standard/test_admin_client_index_create.py +++ b/tests/standard/test_admin_client_index_create.py @@ -62,14 +62,14 @@ def __init__( ) ], ) -def test_index_create(session_admin_client, test_case, random_name): +def test_index_create(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, 
name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -82,7 +82,7 @@ def test_index_create(session_admin_client, test_case, random_name): timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -100,7 +100,7 @@ def test_index_create(session_admin_client, test_case, random_name): assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -132,15 +132,15 @@ def test_index_create(session_admin_client, test_case, random_name): ), ], ) -def test_index_create_with_dimnesions(session_admin_client, test_case, random_name): +def test_index_create_with_dimnesions(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -153,7 +153,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: @@ -174,7 +174,7 @@ def test_index_create_with_dimnesions(session_admin_client, 
test_case, random_na assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -229,16 +229,16 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na ], ) def test_index_create_with_vector_distance_metric( - session_admin_client, test_case, random_name + session_vector_client, test_case, random_name ): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -250,7 +250,7 @@ def test_index_create_with_vector_distance_metric( index_storage=test_case.index_storage, timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -268,7 +268,7 @@ def test_index_create_with_vector_distance_metric( assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -300,15 +300,15 @@ def test_index_create_with_vector_distance_metric( ), ], ) -def test_index_create_with_sets(session_admin_client, test_case, random_name): +def test_index_create_with_sets(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, 
name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -320,7 +320,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): index_storage=test_case.index_storage, timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -338,7 +338,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -429,13 +429,13 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): ), ], ) -def test_index_create_with_index_params(session_admin_client, test_case, random_name): +def test_index_create_with_index_params(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -448,7 +448,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ timeout=test_case.timeout, ) - results = session_admin_client.index_list() + 
results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -534,7 +534,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -555,13 +555,13 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ ) ], ) -def test_index_create_index_labels(session_admin_client, test_case, random_name): +def test_index_create_index_labels(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -574,7 +574,7 @@ def test_index_create_index_labels(session_admin_client, test_case, random_name) timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -595,7 +595,7 @@ def test_index_create_index_labels(session_admin_client, test_case, random_name) assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -616,13 +616,13 
@@ def test_index_create_index_labels(session_admin_client, test_case, random_name) ), ], ) -def test_index_create_index_storage(session_admin_client, test_case, random_name): +def test_index_create_index_storage(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -635,7 +635,7 @@ def test_index_create_index_storage(session_admin_client, test_case, random_name timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -674,20 +674,20 @@ def test_index_create_index_storage(session_admin_client, test_case, random_name ], ) def test_index_create_timeout( - session_admin_client, test_case, random_name, with_latency + session_vector_client, test_case, random_name, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass for i in range(10): try: - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, diff --git a/tests/standard/test_admin_client_index_drop.py b/tests/standard/test_admin_client_index_drop.py index cf05a88f..a21f90c6 100644 --- a/tests/standard/test_admin_client_index_drop.py +++ 
b/tests/standard/test_admin_client_index_drop.py @@ -11,10 +11,10 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_drop(session_admin_client, empty_test_case, index): - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) +def test_index_drop(session_vector_client, empty_test_case, index): + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) - result = session_admin_client.index_list() + result = session_vector_client.index_list() result = result for index in result: assert index["id"]["name"] != index @@ -24,7 +24,7 @@ def test_index_drop(session_admin_client, empty_test_case, index): #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_drop_timeout( - session_admin_client, + session_vector_client, empty_test_case, index, with_latency @@ -34,7 +34,7 @@ def test_index_drop_timeout( for i in range(10): try: - session_admin_client.index_drop( + session_vector_client.index_drop( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/test_admin_client_index_get.py b/tests/standard/test_admin_client_index_get.py index c6a77e00..4412b57e 100644 --- a/tests/standard/test_admin_client_index_get.py +++ b/tests/standard/test_admin_client_index_get.py @@ -8,8 +8,8 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get(session_admin_client, empty_test_case, index): - result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) +def test_index_get(session_vector_client, empty_test_case, index): + result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result["id"]["name"] == index assert result["id"]["namespace"] == DEFAULT_NAMESPACE @@ -47,9 
+47,9 @@ def test_index_get(session_admin_client, empty_test_case, index): @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_no_defaults(session_admin_client, empty_test_case, index): +async def test_index_get_no_defaults(session_vector_client, empty_test_case, index): - result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) + result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) assert result["id"]["name"] == index assert result["id"]["namespace"] == DEFAULT_NAMESPACE @@ -89,14 +89,14 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, inde #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_timeout( - session_admin_client, empty_test_case, index, with_latency + session_vector_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") for i in range(10): try: - result = session_admin_client.index_get( + result = session_vector_client.index_get( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) diff --git a/tests/standard/test_admin_client_index_get_status.py b/tests/standard/test_admin_client_index_get_status.py index 93fa8c83..016c0027 100644 --- a/tests/standard/test_admin_client_index_get_status.py +++ b/tests/standard/test_admin_client_index_get_status.py @@ -13,25 +13,25 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get_status(session_admin_client, empty_test_case, index): - result = session_admin_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) +def test_index_get_status(session_vector_client, empty_test_case, index): + result = 
session_vector_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) assert result.unmerged_record_count == 0 - drop_specified_index(session_admin_client, DEFAULT_NAMESPACE, index) + drop_specified_index(session_vector_client, DEFAULT_NAMESPACE, index) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_status_timeout( - session_admin_client, empty_test_case, index, with_latency + session_vector_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") for i in range(10): try: - result = session_admin_client.index_get_status( + result = session_vector_client.index_get_status( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/test_admin_client_index_list.py b/tests/standard/test_admin_client_index_list.py index 55e3978b..9c6c51f0 100644 --- a/tests/standard/test_admin_client_index_list.py +++ b/tests/standard/test_admin_client_index_list.py @@ -9,8 +9,8 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_list(session_admin_client, empty_test_case, index): - result = session_admin_client.index_list(apply_defaults=True) +def test_index_list(session_vector_client, empty_test_case, index): + result = session_vector_client.index_list(apply_defaults=True) assert len(result) > 0 for index in result: assert isinstance(index["id"]["name"], str) @@ -32,7 +32,7 @@ def test_index_list(session_admin_client, empty_test_case, index): #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_list_timeout( - session_admin_client, empty_test_case, with_latency + session_vector_client, empty_test_case, with_latency ): if not with_latency: @@ -41,7 +41,7 @@ def test_index_list_timeout( for i in range(10): try: - 
result = session_admin_client.index_list(timeout=0.0001) + result = session_vector_client.index_list(timeout=0.0001) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/test_admin_client_index_update.py b/tests/standard/test_admin_client_index_update.py index e5be9822..eece36e4 100644 --- a/tests/standard/test_admin_client_index_update.py +++ b/tests/standard/test_admin_client_index_update.py @@ -44,9 +44,9 @@ def __init__( ), ], ) -def test_index_update(session_admin_client, test_case, index): +def test_index_update(session_vector_client, test_case, index): # Update the index with parameters based on the test case - session_admin_client.index_update( + session_vector_client.index_update( namespace=DEFAULT_NAMESPACE, name=index, index_labels=test_case.update_labels, @@ -57,7 +57,7 @@ def test_index_update(session_admin_client, test_case, index): time.sleep(10) # Verify the update - result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) + result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." 
assert result["id"]["namespace"] == DEFAULT_NAMESPACE diff --git a/tests/standard/test_extensive_vector_search.py b/tests/standard/test_extensive_vector_search.py index 0e9efc43..6d741559 100644 --- a/tests/standard/test_extensive_vector_search.py +++ b/tests/standard/test_extensive_vector_search.py @@ -107,7 +107,6 @@ def grade_results( truth_numpy, query_numpy, session_vector_client, - session_admin_client, name, ): @@ -157,14 +156,13 @@ def test_vector_search( truth_numpy, query_numpy, session_vector_client, - session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_admin_client.index_create( + session_vector_client.index_create( namespace="test", name="demo1", vector_field="unit_test", @@ -185,7 +183,6 @@ def test_vector_search( truth_numpy, query_numpy, session_vector_client, - session_admin_client, name="demo1", ) @@ -195,14 +192,13 @@ def test_vector_search_with_set_same_as_index( truth_numpy, query_numpy, session_vector_client, - session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_admin_client.index_create( + session_vector_client.index_create( namespace="test", name="demo2", sets="demo2", @@ -228,7 +224,6 @@ def test_vector_search_with_set_same_as_index( truth_numpy, query_numpy, session_vector_client, - session_admin_client, name="demo2", ) @@ -238,14 +233,13 @@ def test_vector_search_with_set_different_than_name( truth_numpy, query_numpy, session_vector_client, - session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_admin_client.index_create( + session_vector_client.index_create( namespace="test", name="demo3", vector_field="unit_test", @@ -268,7 +262,6 @@ def test_vector_search_with_set_different_than_name( truth_numpy, query_numpy, session_vector_client, - session_admin_client, name="demo3", ) @@ 
-278,14 +271,13 @@ def test_vector_search_with_index_storage_different_than_name( truth_numpy, query_numpy, session_vector_client, - session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_admin_client.index_create( + session_vector_client.index_create( namespace="test", name="demo4", vector_field="unit_test", @@ -308,7 +300,6 @@ def test_vector_search_with_index_storage_different_than_name( truth_numpy, query_numpy, session_vector_client, - session_admin_client, name="demo4", ) @@ -318,14 +309,13 @@ def test_vector_search_with_index_storage_different_location( truth_numpy, query_numpy, session_vector_client, - session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_admin_client.index_create( + session_vector_client.index_create( namespace="test", name="demo5", vector_field="unit_test", @@ -348,7 +338,6 @@ def test_vector_search_with_index_storage_different_location( truth_numpy, query_numpy, session_vector_client, - session_admin_client, name="demo5", ) @@ -358,14 +347,13 @@ def test_vector_search_with_separate_namespace( truth_numpy, query_numpy, session_vector_client, - session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_admin_client.index_create( + session_vector_client.index_create( namespace="test", name="demo6", vector_field="unit_test", @@ -388,13 +376,12 @@ def test_vector_search_with_separate_namespace( truth_numpy, query_numpy, session_vector_client, - session_admin_client, name="demo6", ) def test_vector_vector_search_timeout( - session_vector_client, session_admin_client, with_latency + session_vector_client, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") diff --git a/tests/standard/test_vector_client_is_indexed.py 
b/tests/standard/test_vector_client_is_indexed.py index b2434bda..5284fd6e 100644 --- a/tests/standard/test_vector_client_is_indexed.py +++ b/tests/standard/test_vector_client_is_indexed.py @@ -8,14 +8,13 @@ def test_vector_is_indexed( - session_admin_client, session_vector_client, index, record, ): # wait for the record to be indexed wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace=DEFAULT_NAMESPACE, index=index ) diff --git a/tests/standard/test_vector_client_search_by_key.py b/tests/standard/test_vector_client_search_by_key.py index 5ea3df98..d0141023 100644 --- a/tests/standard/test_vector_client_search_by_key.py +++ b/tests/standard/test_vector_client_search_by_key.py @@ -265,11 +265,10 @@ def __init__( ) def test_vector_search_by_key( session_vector_client, - session_admin_client, test_case, ): - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.search_namespace, name=test_case.index_name, vector_field=test_case.vector_field, @@ -298,7 +297,7 @@ def test_vector_search_by_key( ) wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace=test_case.search_namespace, index=test_case.index_name, ) @@ -324,7 +323,7 @@ def test_vector_search_by_key( key=key, ) - session_admin_client.index_drop( + session_vector_client.index_drop( namespace=test_case.search_namespace, name=test_case.index_name, ) @@ -332,10 +331,9 @@ def test_vector_search_by_key( def test_vector_search_by_key_different_namespaces( session_vector_client, - session_admin_client, ): - session_admin_client.index_create( + session_vector_client.index_create( namespace="index_storage", name="diff_ns_idx", vector_field="vec", @@ -374,7 +372,7 @@ def test_vector_search_by_key_different_namespaces( ) wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace="index_storage", index="diff_ns_idx", ) @@ -415,7 +413,7 @@ def 
test_vector_search_by_key_different_namespaces( key="search_for", ) - session_admin_client.index_drop( + session_vector_client.index_drop( namespace="index_storage", name="diff_ns_idx", ) \ No newline at end of file diff --git a/tests/standard/test_vector_search.py b/tests/standard/test_vector_search.py index a18414b4..82c6326f 100644 --- a/tests/standard/test_vector_search.py +++ b/tests/standard/test_vector_search.py @@ -99,11 +99,10 @@ def __init__( ) def test_vector_search( session_vector_client, - session_admin_client, test_case, ): - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=test_case.index_name, vector_field=test_case.vector_field, @@ -132,7 +131,7 @@ def test_vector_search( ) wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace=test_case.namespace, index=test_case.index_name, ) @@ -154,7 +153,7 @@ def test_vector_search( key=key, ) - session_admin_client.index_drop( + session_vector_client.index_drop( namespace=test_case.namespace, name=test_case.index_name, ) From e98ef12307b1f0b4ee8d29d3b1f1b3af04271652 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 2 Jan 2025 14:58:32 -0800 Subject: [PATCH 19/21] Revert "refactor:! merged sync admin and non-admin clients" This reverts commit 8717a88e068e0b94ba150e687d7c30a7943763a4. 
--- docs/admin.rst | 10 + src/aerospike_vector_search/__init__.py | 1 + src/aerospike_vector_search/admin.py | 756 ++++++++++++++++++ src/aerospike_vector_search/client.py | 655 +-------------- .../shared/client_helpers.py | 6 + tests/rbac/sync/conftest.py | 7 +- tests/rbac/sync/test_admin_client_add_user.py | 12 +- .../rbac/sync/test_admin_client_drop_user.py | 8 +- tests/rbac/sync/test_admin_client_get_user.py | 6 +- .../sync/test_admin_client_grant_roles.py | 8 +- .../rbac/sync/test_admin_client_list_roles.py | 6 +- .../rbac/sync/test_admin_client_list_users.py | 6 +- .../sync/test_admin_client_revoke_roles.py | 8 +- .../test_admin_client_update_credentials.py | 8 +- tests/standard/conftest.py | 91 ++- .../test_admin_client_index_create.py | 74 +- .../standard/test_admin_client_index_drop.py | 10 +- tests/standard/test_admin_client_index_get.py | 12 +- .../test_admin_client_index_get_status.py | 10 +- .../standard/test_admin_client_index_list.py | 8 +- .../test_admin_client_index_update.py | 6 +- .../standard/test_extensive_vector_search.py | 27 +- .../standard/test_vector_client_is_indexed.py | 3 +- .../test_vector_client_search_by_key.py | 14 +- tests/standard/test_vector_search.py | 7 +- 25 files changed, 991 insertions(+), 768 deletions(-) create mode 100644 docs/admin.rst create mode 100644 src/aerospike_vector_search/admin.py diff --git a/docs/admin.rst b/docs/admin.rst new file mode 100644 index 00000000..eea13706 --- /dev/null +++ b/docs/admin.rst @@ -0,0 +1,10 @@ +AdminClient +===================== + +This class is the admin client, designed to conduct AVS administrative operation such as creating indexes, querying index information, and dropping indexes. + + +.. 
autoclass:: aerospike_vector_search.admin.Client + :members: + :undoc-members: + :show-inheritance: diff --git a/src/aerospike_vector_search/__init__.py b/src/aerospike_vector_search/__init__.py index d08c2b3b..cc942ecb 100644 --- a/src/aerospike_vector_search/__init__.py +++ b/src/aerospike_vector_search/__init__.py @@ -1,4 +1,5 @@ from .client import Client +from .admin import Client as AdminClient from .types import ( HostPort, Key, diff --git a/src/aerospike_vector_search/admin.py b/src/aerospike_vector_search/admin.py new file mode 100644 index 00000000..a9638d09 --- /dev/null +++ b/src/aerospike_vector_search/admin.py @@ -0,0 +1,756 @@ +import logging +import sys +import time +from typing import Optional, Union + +import grpc + +from . import types +from .internal import channel_provider +from .shared.admin_helpers import BaseClient +from .shared.conversions import fromIndexStatusResponse +from .types import IndexDefinition, Role + +logger = logging.getLogger(__name__) + + +class Client(BaseClient): + """ + Aerospike Vector Search Admin Client + + This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. + + :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster. + :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] + + :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. + :type listener_name: Optional[str] + + :param is_loadbalancer: If true, the first seed address will be treated as a load balancer node. Defaults to False. + :type is_loadbalancer: Optional[bool] + + :param service_config_path: Path to the service configuration file. 
Defaults to None. + :type service_config_path: Optional[str] + + :param username: Username for Role-Based Access. Defaults to None. + :type username: Optional[str] + + :param password: Password for Role-Based Access. Defaults to None. + :type password: Optional[str] + + :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. + :type root_certificate: Optional[list[bytes], bytes] + + :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. + :type certificate_chain: Optional[bytes] + + :param private_key: The PEM-encoded private key as a byte string. Defaults to None. + :type private_key: Optional[bytes] + + :raises AVSClientError: Raised when no seed host is provided. + + """ + + def __init__( + self, + *, + seeds: Union[types.HostPort, tuple[types.HostPort, ...]], + listener_name: Optional[str] = None, + is_loadbalancer: Optional[bool] = False, + username: Optional[str] = None, + password: Optional[str] = None, + root_certificate: Optional[Union[list[str], str]] = None, + certificate_chain: Optional[str] = None, + private_key: Optional[str] = None, + service_config_path: Optional[str] = None, + ssl_target_name_override: Optional[str] = None, + ) -> None: + seeds = self._prepare_seeds(seeds) + + self._channel_provider = channel_provider.ChannelProvider( + seeds, + listener_name, + is_loadbalancer, + username, + password, + root_certificate, + certificate_chain, + private_key, + service_config_path, + ssl_target_name_override, + ) + + def index_create( + self, + *, + namespace: str, + name: str, + vector_field: str, + dimensions: int, + vector_distance_metric: types.VectorDistanceMetric = ( + types.VectorDistanceMetric.SQUARED_EUCLIDEAN + ), + sets: Optional[str] = None, + index_params: Optional[types.HnswParams] = None, + index_labels: Optional[dict[str, str]] = None, + index_storage: Optional[types.IndexStorage] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Create an 
index. + + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param vector_field: The name of the field containing vector data. + :type vector_field: str + + :param dimensions: The number of dimensions in the vector data. + :type dimensions: int + + :param vector_distance_metric: + The distance metric used to compare when performing a vector search. + Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. + :type vector_distance_metric: types.VectorDistanceMetric + + :param sets: The set used for the index. Defaults to None. + :type sets: Optional[str] + + :param index_params: Parameters used for tuning + vector search. Defaults to None. If index_params is None, then the default values + specified for :class:`types.HnswParams` will be used. + :type index_params: Optional[types.HnswParams] + + :param index_labels: Metadata associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param index_storage: Namespace and set where index overhead (non-vector data) is stored. + :type index_storage: Optional[types.IndexStorage] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to 100,000. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method creates an index with the specified parameters and waits for the index creation to complete. + It waits for up to 100,000 seconds for the index creation to complete. 
+ """ + + (index_stub, index_create_request, kwargs) = self._prepare_index_create( + namespace, + name, + vector_field, + dimensions, + vector_distance_metric, + sets, + index_params, + index_labels, + index_storage, + timeout, + logger, + ) + + try: + index_stub.Create( + index_create_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to create index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + self._wait_for_index_creation( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def index_update( + self, + *, + namespace: str, + name: str, + index_labels: Optional[dict[str, str]] = None, + hnsw_update_params: Optional[types.HnswIndexUpdate] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Update an existing index. + + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param index_labels: Optional labels associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param hnsw_update_params: Parameters for updating HNSW index settings. + :type hnsw_update_params: Optional[types.HnswIndexUpdate] + + :param timeout: Time in seconds (default 100_000) this operation will wait before raising an error. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. 
+ """ + (index_stub, index_update_request, kwargs) = self._prepare_index_update( + namespace = namespace, + name = name, + index_labels = index_labels, + hnsw_update_params = hnsw_update_params, + timeout = timeout, + logger = logger, + ) + + try: + index_stub.Update( + index_update_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + def index_drop( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> None: + """ + Drop an index. + + :param namespace: The namespace of the index. + :type name: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method drops an index with the specified parameters and waits for the index deletion to complete. + It waits for up to 100,000 seconds for the index deletion to complete. 
+ """ + + (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( + namespace, name, timeout, logger + ) + + try: + index_stub.Drop( + index_drop_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + self._wait_for_index_deletion( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for deletion with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def index_list( + self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True + ) -> list[IndexDefinition]: + """ + List all indices. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: list[dict]: A list of indices. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + + (index_stub, index_list_request, kwargs) = self._prepare_index_list( + timeout, logger, apply_defaults + ) + + try: + response = index_stub.List( + index_list_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list indexes with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_list(response) + + def index_get( + self, + *, + namespace: str, + name: str, + timeout: Optional[int] = None, + apply_defaults: Optional[bool] = True, + ) -> IndexDefinition: + """ + Retrieve the information related with an index. 
+ + :param namespace: The namespace of the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: IndexDefinition: Information about an index. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + + (index_stub, index_get_request, kwargs) = self._prepare_index_get( + namespace, name, timeout, logger, apply_defaults + ) + + try: + response = index_stub.Get( + index_get_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_get(response) + + def index_get_status( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> types.IndexStatusResponse: + """ + Retrieve the number of records queued to be merged into an index. + + :param namespace: The namespace of the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Returns: IndexStatusResponse: AVS response containing index status information. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + Note: + This method retrieves the status of the specified index. If index_get_status is called right after the vector client puts some records into Aerospike Vector Search, + the records may not immediately begin to merge into the index. + + Warning: This API is subject to change. + """ + (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( + namespace, name, timeout, logger + ) + + try: + response = index_stub.GetStatus( + index_get_status_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + return fromIndexStatusResponse(response) + except grpc.RpcError as e: + logger.error("Failed to get index status with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + + + def add_user( + self, + *, + username: str, + password: str, + roles: list[str], + timeout: Optional[int] = None, + ) -> None: + """ + Add a role-based access AVS User to the AVS Server. + + :param username: Username for the new user. + :type username: str + + :param password: Password for the new user. + :type password: str + + :param roles: Roles for the new user. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( + username, password, roles, timeout, logger + ) + + try: + user_admin_stub.AddUser( + add_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to add user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def update_credentials( + self, *, username: str, password: str, timeout: Optional[int] = None + ) -> None: + """ + Update AVS User credentials. + + :param username: Username of the user to update. + :type username: str + + :param password: New password for the user. + :type password: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, update_credentials_request, kwargs) = ( + self._prepare_update_credentials(username, password, timeout, logger) + ) + + try: + user_admin_stub.UpdateCredentials( + update_credentials_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update credentials with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: + """ + Drops AVS User from the AVS Server. + + :param username: Username of the user to drop. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
+ :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( + username, timeout, logger + ) + + try: + user_admin_stub.DropUser( + drop_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def get_user(self, *, username: str, timeout: Optional[int] = None) -> types.User: + """ + Retrieves AVS User information from the AVS Server. + + :param username: Username of the user to be retrieved. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + return: types.User: AVS User + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( + username, timeout, logger + ) + + try: + response = user_admin_stub.GetUser( + get_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + return self._respond_get_user(response) + + def list_users(self, timeout: Optional[int] = None) -> list[types.User]: + """ + List all users existing on the AVS Server. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
+ :type timeout: int + + return: list[types.User]: list of AVS Users + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( + timeout, logger + ) + + try: + response = user_admin_stub.ListUsers( + list_users_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_users(response) + + def grant_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) -> None: + """ + Grant roles to existing AVS Users. + + :param username: Username of the user which will receive the roles. + :type username: str + + :param roles: Roles the specified user will receive. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( + username, roles, timeout, logger + ) + + try: + user_admin_stub.GrantRoles( + grant_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to grant roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def revoke_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) -> None: + """ + Revoke roles from existing AVS Users. + + :param username: Username of the user undergoing role removal. + :type username: str + + :param roles: Roles to be revoked. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( + username, roles, timeout, logger + ) + + try: + user_admin_stub.RevokeRoles( + revoke_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to revoke roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def list_roles(self, timeout: Optional[int] = None) -> list[Role]: + """ + List roles available on the AVS server. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + returns: list[str]: Roles available in the AVS Server. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. 
+ This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( + timeout, logger + ) + + try: + response = user_admin_stub.ListRoles( + list_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_roles(response) + + def _wait_for_index_creation( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be created. + """ + + (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + while True: + self._check_timeout(start_time, timeout) + try: + index_stub.GetStatus( + index_creation_request, + credentials=self._channel_provider.get_token(), + ) + logger.debug("Index created successfully") + # Index has been created + return + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + + # Wait for some more time. + time.sleep(wait_interval) + else: + logger.error("Failed waiting for index creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def _wait_for_index_deletion( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be deleted. + """ + + # Wait interval between polling + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + + while True: + self._check_timeout(start_time, timeout) + + try: + index_stub.GetStatus( + index_deletion_request, + credentials=self._channel_provider.get_token(), + ) + # Wait for some more time. 
+ time.sleep(wait_interval) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + logger.debug("Index deleted successfully") + # Index has been deleted + return + else: + logger.error("Failed waiting for index deletion with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def close(self): + """ + Close the Aerospike Vector Search Admin Client. + + This method closes gRPC channels connected to Aerospike Vector Search. + + Note: + This method should be called when the AdminClient is no longer needed to release resources. + """ + self._channel_provider.close() + + def __enter__(self): + """ + Enter a context manager for the admin client. + + Returns: + AdminClient: Aerospike Vector Search Admin Client instance. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit a context manager for the admin client. + """ + self.close() diff --git a/src/aerospike_vector_search/client.py b/src/aerospike_vector_search/client.py index c8292563..b78b26a1 100644 --- a/src/aerospike_vector_search/client.py +++ b/src/aerospike_vector_search/client.py @@ -9,14 +9,12 @@ from . 
import types from .internal import channel_provider -from .shared.admin_helpers import BaseClient as BaseAdminClientMixin -from .shared.client_helpers import BaseClient as BaseClientMixin -from .shared.conversions import fromIndexStatusResponse +from .shared.client_helpers import BaseClient logger = logging.getLogger(__name__) -class Client(BaseClientMixin, BaseAdminClientMixin): +class Client(BaseClient): """ Aerospike Vector Search Client @@ -794,655 +792,6 @@ def wait_for_index_completion( validation_count = 0 time.sleep(wait_interval_float) - def index_create( - self, - *, - namespace: str, - name: str, - vector_field: str, - dimensions: int, - vector_distance_metric: types.VectorDistanceMetric = ( - types.VectorDistanceMetric.SQUARED_EUCLIDEAN - ), - sets: Optional[str] = None, - index_params: Optional[types.HnswParams] = None, - index_labels: Optional[dict[str, str]] = None, - index_storage: Optional[types.IndexStorage] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Create an index. - - :param namespace: The namespace for the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param vector_field: The name of the field containing vector data. - :type vector_field: str - - :param dimensions: The number of dimensions in the vector data. - :type dimensions: int - - :param vector_distance_metric: - The distance metric used to compare when performing a vector search. - Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type vector_distance_metric: types.VectorDistanceMetric - - :param sets: The set used for the index. Defaults to None. - :type sets: Optional[str] - - :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning - vector search. Defaults to None. If index_params is None, then the default values - specified for :class:`types.HnswParams` will be used. 
- :type index_params: Optional[types.HnswParams] - - :param index_labels: Metadata associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param index_storage: Namespace and set where index overhead (non-vector data) is stored. - :type index_storage: Optional[types.IndexStorage] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method creates an index with the specified parameters and waits for the index creation to complete. - It waits for up to 100,000 seconds for the index creation to complete. - """ - - (index_stub, index_create_request, kwargs) = self._prepare_index_create( - namespace, - name, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_labels, - index_storage, - timeout, - logger, - ) - - try: - index_stub.Create( - index_create_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to create index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - self._wait_for_index_creation( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def index_update( - self, - *, - namespace: str, - name: str, - index_labels: Optional[dict[str, str]] = None, - hnsw_update_params: Optional[types.HnswIndexUpdate] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Update an existing index. - - :param namespace: The namespace for the index. 
- :type namespace: str - - :param name: The name of the index. - :type name: str - - :param index_labels: Optional labels associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param hnsw_update_params: Parameters for updating HNSW index settings. - :type hnsw_update_params: Optional[types.HnswIndexUpdate] - - :param timeout: Time in seconds (default 100_000) this operation will wait before raising an error. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. - """ - (index_stub, index_update_request, kwargs) = self._prepare_index_update( - namespace = namespace, - name = name, - index_labels = index_labels, - hnsw_update_params = hnsw_update_params, - timeout = timeout, - logger = logger, - ) - - try: - index_stub.Update( - index_update_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - def index_drop( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> None: - """ - Drop an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method drops an index with the specified parameters and waits for the index deletion to complete. - It waits for up to 100,000 seconds for the index deletion to complete. 
- """ - - (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( - namespace, name, timeout, logger - ) - - try: - index_stub.Drop( - index_drop_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - self._wait_for_index_deletion( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for deletion with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def index_list( - self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True - ) -> list[types.IndexDefinition]: - """ - List all indices. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: list[dict]: A list of indices. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - """ - - (index_stub, index_list_request, kwargs) = self._prepare_index_list( - timeout, logger, apply_defaults - ) - - try: - response = index_stub.List( - index_list_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list indexes with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_list(response) - - def index_get( - self, - *, - namespace: str, - name: str, - timeout: Optional[int] = None, - apply_defaults: Optional[bool] = True, - ) -> types.IndexDefinition: - """ - Retrieve the information related with an index. 
- - :param namespace: The namespace of the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: dict[str, Union[int, str]: Information about an index. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - - (index_stub, index_get_request, kwargs) = self._prepare_index_get( - namespace, name, timeout, logger, apply_defaults - ) - - try: - response = index_stub.Get( - index_get_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_get(response) - - def index_get_status( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> types.IndexStatusResponse: - """ - Retrieve the number of records queued to be merged into an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Returns: IndexStatusResponse: AVS response containing index status information. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - Note: - This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, - the records may not immediately begin to merge into the index. - - Warning: This API is subject to change. - """ - (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( - namespace, name, timeout, logger - ) - - try: - response = index_stub.GetStatus( - index_get_status_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - return fromIndexStatusResponse(response) - except grpc.RpcError as e: - logger.error("Failed to get index status with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - - - def add_user( - self, - *, - username: str, - password: str, - roles: list[str], - timeout: Optional[int] = None, - ) -> None: - """ - Add role-based access AVS User to the AVS Server. - - :param username: Username for the new user. - :type username: str - - :param password: Password for the new user. - :type password: str - - :param roles: Roles for the new user. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( - username, password, roles, timeout, logger - ) - - try: - user_admin_stub.AddUser( - add_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to add user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def update_credentials( - self, *, username: str, password: str, timeout: Optional[int] = None - ) -> None: - """ - Update AVS User credentials. - - :param username: Username of the user to update. - :type username: str - - :param password: New password for the user. - :type password: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, update_credentials_request, kwargs) = ( - self._prepare_update_credentials(username, password, timeout, logger) - ) - - try: - user_admin_stub.UpdateCredentials( - update_credentials_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update credentials with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: - """ - Drops AVS User from the AVS Server. - - :param username: Username of the user to drop. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( - username, timeout, logger - ) - - try: - user_admin_stub.DropUser( - drop_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def get_user(self, *, username: str, timeout: Optional[int] = None) -> types.User: - """ - Retrieves AVS User information from the AVS Server. - - :param username: Username of the user to be retrieved. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - return: types.User: AVS User - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( - username, timeout, logger - ) - - try: - response = user_admin_stub.GetUser( - get_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - return self._respond_get_user(response) - - def list_users(self, timeout: Optional[int] = None) -> list[types.User]: - """ - List all users existing on the AVS Server. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type timeout: int - - return: list[types.User]: list of AVS Users - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( - timeout, logger - ) - - try: - response = user_admin_stub.ListUsers( - list_users_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_users(response) - - def grant_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> None: - """ - Grant roles to existing AVS Users. - - :param username: Username of the user which will receive the roles. - :type username: str - - :param roles: Roles the specified user will receive. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( - username, roles, timeout, logger - ) - - try: - user_admin_stub.GrantRoles( - grant_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to grant roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def revoke_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> None: - """ - Revoke roles from existing AVS Users. - - :param username: Username of the user undergoing role removal. - :type username: str - - :param roles: Roles to be revoked. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( - username, roles, timeout, logger - ) - - try: - user_admin_stub.RevokeRoles( - revoke_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to revoke roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def list_roles(self, timeout: Optional[int] = None) -> list[types.Role]: - """ - List roles available on the AVS server. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - returns: list[str]: Roles available in the AVS Server. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. 
- This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - """ - (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( - timeout, logger - ) - - try: - response = user_admin_stub.ListRoles( - list_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_roles(response) - - def _wait_for_index_creation( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be created. - """ - - (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - while True: - self._check_timeout(start_time, timeout) - try: - index_stub.GetStatus( - index_creation_request, - credentials=self._channel_provider.get_token(), - ) - logger.debug("Index created successfully") - # Index has been created - return - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - - # Wait for some more time. - time.sleep(wait_interval) - else: - logger.error("Failed waiting for index creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def _wait_for_index_deletion( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be deleted. - """ - - # Wait interval between polling - (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - - while True: - self._check_timeout(start_time, timeout) - - try: - index_stub.GetStatus( - index_deletion_request, - credentials=self._channel_provider.get_token(), - ) - # Wait for some more time. 
- time.sleep(wait_interval) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - logger.debug("Index deleted successfully") - # Index has been created - return - else: - logger.error("Failed waiting for index deletion with error: %s", e) - raise types.AVSServerError(rpc_error=e) - def close(self): """ Close the Aerospike Vector Search Client. diff --git a/src/aerospike_vector_search/shared/client_helpers.py b/src/aerospike_vector_search/shared/client_helpers.py index 80a831e2..2ae5c12d 100644 --- a/src/aerospike_vector_search/shared/client_helpers.py +++ b/src/aerospike_vector_search/shared/client_helpers.py @@ -352,6 +352,12 @@ def _get_key( else: raise Exception("Invalid key type" + str(type(key))) return key + + def _prepare_wait_for_index_waiting(self, namespace: str, name: str, wait_interval: int) -> ( + Tuple)[index_pb2_grpc.IndexServiceStub, float, float, bool, int, index_pb2.IndexGetRequest]: + return helpers._prepare_wait_for_index_waiting( + self, namespace, name, wait_interval + ) def _prepare_index_get_percent_unmerged(self, namespace: str, name: str, timeout: Optional[int], logger: Logger) -> ( Tuple)[index_pb2_grpc.IndexServiceStub, index_pb2.IndexStatusRequest, dict[str, Any]]: diff --git a/tests/rbac/sync/conftest.py b/tests/rbac/sync/conftest.py index 5b8f98ed..5c986c56 100644 --- a/tests/rbac/sync/conftest.py +++ b/tests/rbac/sync/conftest.py @@ -1,6 +1,7 @@ import pytest from aerospike_vector_search import Client +from aerospike_vector_search.admin import Client as AdminClient from aerospike_vector_search import types @@ -28,7 +29,7 @@ def drop_all_indexes( with open(private_key, "rb") as f: private_key = f.read() - with Client( + with AdminClient( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -46,7 +47,7 @@ def drop_all_indexes( @pytest.fixture(scope="module") -def session_rbac_client( +def session_rbac_admin_client( username, password, root_certificate, @@ -69,7 +70,7 @@ 
def session_rbac_client( with open(private_key, "rb") as f: private_key = f.read() - client = Client( + client = AdminClient( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, diff --git a/tests/rbac/sync/test_admin_client_add_user.py b/tests/rbac/sync/test_admin_client_add_user.py index 381b80a6..aa5a27c6 100644 --- a/tests/rbac/sync/test_admin_client_add_user.py +++ b/tests/rbac/sync/test_admin_client_add_user.py @@ -25,12 +25,12 @@ def __init__( ), ], ) -def test_add_user(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_add_user(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_client.get_user(username=test_case.username) + result = session_rbac_admin_client.get_user(username=test_case.username) assert result.username == test_case.username @@ -57,12 +57,12 @@ def test_add_user(session_rbac_client, test_case): ), ], ) -def test_add_user_with_roles(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_add_user_with_roles(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_client.get_user(username=test_case.username) + result = session_rbac_admin_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_drop_user.py b/tests/rbac/sync/test_admin_client_drop_user.py index e665836a..a18bee71 100644 --- a/tests/rbac/sync/test_admin_client_drop_user.py +++ b/tests/rbac/sync/test_admin_client_drop_user.py @@ -24,13 +24,13 @@ def __init__( ), ], ) -def test_drop_user(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_drop_user(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( 
username=test_case.username, password=test_case.password, roles=None ) - session_rbac_client.drop_user( + session_rbac_admin_client.drop_user( username=test_case.username, ) with pytest.raises(AVSServerError) as e_info: - result = session_rbac_client.get_user(username=test_case.username) + result = session_rbac_admin_client.get_user(username=test_case.username) assert e_info.value.rpc_error.code() == grpc.StatusCode.NOT_FOUND diff --git a/tests/rbac/sync/test_admin_client_get_user.py b/tests/rbac/sync/test_admin_client_get_user.py index 61bdc405..69c76a92 100644 --- a/tests/rbac/sync/test_admin_client_get_user.py +++ b/tests/rbac/sync/test_admin_client_get_user.py @@ -22,12 +22,12 @@ def __init__( ), ], ) -def test_get_user(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_get_user(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - result = session_rbac_client.get_user(username=test_case.username) + result = session_rbac_admin_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_grant_roles.py b/tests/rbac/sync/test_admin_client_grant_roles.py index 8ce18884..dc74bef0 100644 --- a/tests/rbac/sync/test_admin_client_grant_roles.py +++ b/tests/rbac/sync/test_admin_client_grant_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, granted_roles): ), ], ) -def test_grant_roles(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_grant_roles(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - session_rbac_client.grant_roles( + session_rbac_admin_client.grant_roles( username=test_case.username, roles=test_case.granted_roles ) - result = session_rbac_client.get_user(username=test_case.username) + result = 
session_rbac_admin_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_list_roles.py b/tests/rbac/sync/test_admin_client_list_roles.py index 35030910..a3d1bf82 100644 --- a/tests/rbac/sync/test_admin_client_list_roles.py +++ b/tests/rbac/sync/test_admin_client_list_roles.py @@ -25,11 +25,11 @@ def __init__( ), ], ) -def test_list_roles(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_list_roles(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_client.list_roles() + result = session_rbac_admin_client.list_roles() for role in result: assert role.id in test_case.roles diff --git a/tests/rbac/sync/test_admin_client_list_users.py b/tests/rbac/sync/test_admin_client_list_users.py index d029472d..0b713250 100644 --- a/tests/rbac/sync/test_admin_client_list_users.py +++ b/tests/rbac/sync/test_admin_client_list_users.py @@ -17,12 +17,12 @@ def __init__(self, *, username, password): ), ], ) -def test_list_users(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_list_users(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - result = session_rbac_client.list_users() + result = session_rbac_admin_client.list_users() user_found = False for user in result: if user.username == test_case.username: diff --git a/tests/rbac/sync/test_admin_client_revoke_roles.py b/tests/rbac/sync/test_admin_client_revoke_roles.py index 04d7d704..0620fb24 100644 --- a/tests/rbac/sync/test_admin_client_revoke_roles.py +++ b/tests/rbac/sync/test_admin_client_revoke_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, revoked_roles): ), ], ) -def test_revoke_roles(session_rbac_client, test_case): - 
session_rbac_client.add_user( +def test_revoke_roles(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - session_rbac_client.revoke_roles( + session_rbac_admin_client.revoke_roles( username=test_case.username, roles=test_case.roles ) - result = session_rbac_client.get_user(username=test_case.username) + result = session_rbac_admin_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_update_credentials.py b/tests/rbac/sync/test_admin_client_update_credentials.py index 328c05b2..3e2d7894 100644 --- a/tests/rbac/sync/test_admin_client_update_credentials.py +++ b/tests/rbac/sync/test_admin_client_update_credentials.py @@ -19,17 +19,17 @@ def __init__(self, *, username, old_password, new_password): ), ], ) -def test_update_credentials(session_rbac_client, test_case): - session_rbac_client.add_user( +def test_update_credentials(session_rbac_admin_client, test_case): + session_rbac_admin_client.add_user( username=test_case.username, password=test_case.old_password, roles=None ) - session_rbac_client.update_credentials( + session_rbac_admin_client.update_credentials( username=test_case.username, password=test_case.new_password, ) - result = session_rbac_client.get_user(username=test_case.username) + result = session_rbac_admin_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/standard/conftest.py b/tests/standard/conftest.py index 020b51a6..f5590e64 100644 --- a/tests/standard/conftest.py +++ b/tests/standard/conftest.py @@ -5,6 +5,8 @@ from aerospike_vector_search import Client from aerospike_vector_search.aio import Client as AsyncClient +from aerospike_vector_search.admin import Client as AdminClient +from aerospike_vector_search.aio.admin import Client as AsyncAdminClient from aerospike_vector_search import types, 
AVSServerError from utils import gen_records, DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD @@ -61,7 +63,7 @@ def drop_all_indexes( with open(private_key, "rb") as f: private_key = f.read() - with Client( + with AdminClient( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -127,6 +129,30 @@ async def new_wrapped_async_client( ) +async def new_wrapped_async_admin_client( + host, + port, + username, + password, + root_certificate, + certificate_chain, + private_key, + is_loadbalancer, + ssl_target_name_override, + loop +): + return AsyncAdminClient( + seeds=types.HostPort(host=host, port=port), + is_loadbalancer=is_loadbalancer, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + ssl_target_name_override=ssl_target_name_override + ) + + class AsyncClientWrapper(): def __init__(self, client, loop): self.client = client @@ -151,6 +177,63 @@ def _run_async_task(self, task): raise RuntimeError("Event loop is not running") +@pytest.fixture(scope="module") +def session_admin_client( + username, + password, + root_certificate, + host, + port, + certificate_chain, + private_key, + is_loadbalancer, + ssl_target_name_override, + async_client, + event_loop, +): + + if root_certificate: + with open(root_certificate, "rb") as f: + root_certificate = f.read() + + if certificate_chain: + with open(certificate_chain, "rb") as f: + certificate_chain = f.read() + if private_key: + with open(private_key, "rb") as f: + private_key = f.read() + + if async_client: + task = new_wrapped_async_admin_client( + host=host, + port=port, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + is_loadbalancer=is_loadbalancer, + ssl_target_name_override=ssl_target_name_override, + loop=event_loop + ) + client = 
asyncio.run_coroutine_threadsafe(task, event_loop).result() + client = AsyncClientWrapper(client, event_loop) + else: + client = AdminClient( + seeds=types.HostPort(host=host, port=port), + is_loadbalancer=is_loadbalancer, + username=username, + password=password, + root_certificate=root_certificate, + certificate_chain=certificate_chain, + private_key=private_key, + ssl_target_name_override=ssl_target_name_override + ) + + yield client + client.close() + + @pytest.fixture(scope="module") def session_vector_client( username, @@ -215,12 +298,12 @@ def index_name(): @pytest.fixture(params=[DEFAULT_INDEX_ARGS]) -def index(session_vector_client, index_name, request): +def index(session_admin_client, index_name, request): args = request.param namespace = args.get("namespace", DEFAULT_NAMESPACE) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - session_vector_client.index_create( + session_admin_client.index_create( name = index_name, namespace = namespace, vector_field = vector_field, @@ -241,7 +324,7 @@ def index(session_vector_client, index_name, request): ) yield index_name try: - session_vector_client.index_drop(namespace=namespace, name=index_name) + session_admin_client.index_drop(namespace=namespace, name=index_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass diff --git a/tests/standard/test_admin_client_index_create.py b/tests/standard/test_admin_client_index_create.py index bb64ae47..d869506f 100644 --- a/tests/standard/test_admin_client_index_create.py +++ b/tests/standard/test_admin_client_index_create.py @@ -62,14 +62,14 @@ def __init__( ) ], ) -def test_index_create(session_vector_client, test_case, random_name): +def test_index_create(session_admin_client, test_case, random_name): try: - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, 
name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -82,7 +82,7 @@ def test_index_create(session_vector_client, test_case, random_name): timeout=test_case.timeout, ) - results = session_vector_client.index_list() + results = session_admin_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -100,7 +100,7 @@ def test_index_create(session_vector_client, test_case, random_name): assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_vector_client, test_case.namespace, random_name) + drop_specified_index(session_admin_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -132,15 +132,15 @@ def test_index_create(session_vector_client, test_case, random_name): ), ], ) -def test_index_create_with_dimnesions(session_vector_client, test_case, random_name): +def test_index_create_with_dimnesions(session_admin_client, test_case, random_name): try: - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -153,7 +153,7 @@ def test_index_create_with_dimnesions(session_vector_client, test_case, random_n timeout=test_case.timeout, ) - results = session_vector_client.index_list() + results = session_admin_client.index_list() found = False for result in results: @@ -174,7 +174,7 @@ def 
test_index_create_with_dimnesions(session_vector_client, test_case, random_n assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_vector_client, test_case.namespace, random_name) + drop_specified_index(session_admin_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -229,16 +229,16 @@ def test_index_create_with_dimnesions(session_vector_client, test_case, random_n ], ) def test_index_create_with_vector_distance_metric( - session_vector_client, test_case, random_name + session_admin_client, test_case, random_name ): try: - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -250,7 +250,7 @@ def test_index_create_with_vector_distance_metric( index_storage=test_case.index_storage, timeout=test_case.timeout, ) - results = session_vector_client.index_list() + results = session_admin_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -268,7 +268,7 @@ def test_index_create_with_vector_distance_metric( assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_vector_client, test_case.namespace, random_name) + drop_specified_index(session_admin_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -300,15 +300,15 @@ def test_index_create_with_vector_distance_metric( ), ], ) -def test_index_create_with_sets(session_vector_client, test_case, random_name): +def test_index_create_with_sets(session_admin_client, test_case, random_name): try: - 
session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -320,7 +320,7 @@ def test_index_create_with_sets(session_vector_client, test_case, random_name): index_storage=test_case.index_storage, timeout=test_case.timeout, ) - results = session_vector_client.index_list() + results = session_admin_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -338,7 +338,7 @@ def test_index_create_with_sets(session_vector_client, test_case, random_name): assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_vector_client, test_case.namespace, random_name) + drop_specified_index(session_admin_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -429,13 +429,13 @@ def test_index_create_with_sets(session_vector_client, test_case, random_name): ), ], ) -def test_index_create_with_index_params(session_vector_client, test_case, random_name): +def test_index_create_with_index_params(session_admin_client, test_case, random_name): try: - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -448,7 +448,7 @@ def test_index_create_with_index_params(session_vector_client, test_case, random 
timeout=test_case.timeout, ) - results = session_vector_client.index_list() + results = session_admin_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -534,7 +534,7 @@ def test_index_create_with_index_params(session_vector_client, test_case, random assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_vector_client, test_case.namespace, random_name) + drop_specified_index(session_admin_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -555,13 +555,13 @@ def test_index_create_with_index_params(session_vector_client, test_case, random ) ], ) -def test_index_create_index_labels(session_vector_client, test_case, random_name): +def test_index_create_index_labels(session_admin_client, test_case, random_name): try: - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -574,7 +574,7 @@ def test_index_create_index_labels(session_vector_client, test_case, random_name timeout=test_case.timeout, ) - results = session_vector_client.index_list() + results = session_admin_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -595,7 +595,7 @@ def test_index_create_index_labels(session_vector_client, test_case, random_name assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_vector_client, test_case.namespace, random_name) + drop_specified_index(session_admin_client, 
test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -616,13 +616,13 @@ def test_index_create_index_labels(session_vector_client, test_case, random_name ), ], ) -def test_index_create_index_storage(session_vector_client, test_case, random_name): +def test_index_create_index_storage(session_admin_client, test_case, random_name): try: - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -635,7 +635,7 @@ def test_index_create_index_storage(session_vector_client, test_case, random_nam timeout=test_case.timeout, ) - results = session_vector_client.index_list() + results = session_admin_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -674,20 +674,20 @@ def test_index_create_index_storage(session_vector_client, test_case, random_nam ], ) def test_index_create_timeout( - session_vector_client, test_case, random_name, with_latency + session_admin_client, test_case, random_name, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") try: - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass for i in range(10): try: - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, diff --git a/tests/standard/test_admin_client_index_drop.py b/tests/standard/test_admin_client_index_drop.py index a21f90c6..cf05a88f 100644 --- 
a/tests/standard/test_admin_client_index_drop.py +++ b/tests/standard/test_admin_client_index_drop.py @@ -11,10 +11,10 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_drop(session_vector_client, empty_test_case, index): - session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) +def test_index_drop(session_admin_client, empty_test_case, index): + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) - result = session_vector_client.index_list() + result = session_admin_client.index_list() result = result for index in result: assert index["id"]["name"] != index @@ -24,7 +24,7 @@ def test_index_drop(session_vector_client, empty_test_case, index): #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_drop_timeout( - session_vector_client, + session_admin_client, empty_test_case, index, with_latency @@ -34,7 +34,7 @@ def test_index_drop_timeout( for i in range(10): try: - session_vector_client.index_drop( + session_admin_client.index_drop( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/test_admin_client_index_get.py b/tests/standard/test_admin_client_index_get.py index 4412b57e..c6a77e00 100644 --- a/tests/standard/test_admin_client_index_get.py +++ b/tests/standard/test_admin_client_index_get.py @@ -8,8 +8,8 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get(session_vector_client, empty_test_case, index): - result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) +def test_index_get(session_admin_client, empty_test_case, index): + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result["id"]["name"] == index assert 
result["id"]["namespace"] == DEFAULT_NAMESPACE @@ -47,9 +47,9 @@ def test_index_get(session_vector_client, empty_test_case, index): @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_no_defaults(session_vector_client, empty_test_case, index): +async def test_index_get_no_defaults(session_admin_client, empty_test_case, index): - result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) assert result["id"]["name"] == index assert result["id"]["namespace"] == DEFAULT_NAMESPACE @@ -89,14 +89,14 @@ async def test_index_get_no_defaults(session_vector_client, empty_test_case, ind #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_timeout( - session_vector_client, empty_test_case, index, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") for i in range(10): try: - result = session_vector_client.index_get( + result = session_admin_client.index_get( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) diff --git a/tests/standard/test_admin_client_index_get_status.py b/tests/standard/test_admin_client_index_get_status.py index 016c0027..93fa8c83 100644 --- a/tests/standard/test_admin_client_index_get_status.py +++ b/tests/standard/test_admin_client_index_get_status.py @@ -13,25 +13,25 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get_status(session_vector_client, empty_test_case, index): - result = session_vector_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) +def test_index_get_status(session_admin_client, empty_test_case, index): + result = 
session_admin_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) assert result.unmerged_record_count == 0 - drop_specified_index(session_vector_client, DEFAULT_NAMESPACE, index) + drop_specified_index(session_admin_client, DEFAULT_NAMESPACE, index) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_status_timeout( - session_vector_client, empty_test_case, index, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") for i in range(10): try: - result = session_vector_client.index_get_status( + result = session_admin_client.index_get_status( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/test_admin_client_index_list.py b/tests/standard/test_admin_client_index_list.py index 9c6c51f0..55e3978b 100644 --- a/tests/standard/test_admin_client_index_list.py +++ b/tests/standard/test_admin_client_index_list.py @@ -9,8 +9,8 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_list(session_vector_client, empty_test_case, index): - result = session_vector_client.index_list(apply_defaults=True) +def test_index_list(session_admin_client, empty_test_case, index): + result = session_admin_client.index_list(apply_defaults=True) assert len(result) > 0 for index in result: assert isinstance(index["id"]["name"], str) @@ -32,7 +32,7 @@ def test_index_list(session_vector_client, empty_test_case, index): #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_list_timeout( - session_vector_client, empty_test_case, with_latency + session_admin_client, empty_test_case, with_latency ): if not with_latency: @@ -41,7 +41,7 @@ def test_index_list_timeout( for i in range(10): try: - 
result = session_vector_client.index_list(timeout=0.0001) + result = session_admin_client.index_list(timeout=0.0001) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/test_admin_client_index_update.py b/tests/standard/test_admin_client_index_update.py index eece36e4..e5be9822 100644 --- a/tests/standard/test_admin_client_index_update.py +++ b/tests/standard/test_admin_client_index_update.py @@ -44,9 +44,9 @@ def __init__( ), ], ) -def test_index_update(session_vector_client, test_case, index): +def test_index_update(session_admin_client, test_case, index): # Update the index with parameters based on the test case - session_vector_client.index_update( + session_admin_client.index_update( namespace=DEFAULT_NAMESPACE, name=index, index_labels=test_case.update_labels, @@ -57,7 +57,7 @@ def test_index_update(session_vector_client, test_case, index): time.sleep(10) # Verify the update - result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." 
assert result["id"]["namespace"] == DEFAULT_NAMESPACE diff --git a/tests/standard/test_extensive_vector_search.py b/tests/standard/test_extensive_vector_search.py index 6d741559..0e9efc43 100644 --- a/tests/standard/test_extensive_vector_search.py +++ b/tests/standard/test_extensive_vector_search.py @@ -107,6 +107,7 @@ def grade_results( truth_numpy, query_numpy, session_vector_client, + session_admin_client, name, ): @@ -156,13 +157,14 @@ def test_vector_search( truth_numpy, query_numpy, session_vector_client, + session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_vector_client.index_create( + session_admin_client.index_create( namespace="test", name="demo1", vector_field="unit_test", @@ -183,6 +185,7 @@ def test_vector_search( truth_numpy, query_numpy, session_vector_client, + session_admin_client, name="demo1", ) @@ -192,13 +195,14 @@ def test_vector_search_with_set_same_as_index( truth_numpy, query_numpy, session_vector_client, + session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_vector_client.index_create( + session_admin_client.index_create( namespace="test", name="demo2", sets="demo2", @@ -224,6 +228,7 @@ def test_vector_search_with_set_same_as_index( truth_numpy, query_numpy, session_vector_client, + session_admin_client, name="demo2", ) @@ -233,13 +238,14 @@ def test_vector_search_with_set_different_than_name( truth_numpy, query_numpy, session_vector_client, + session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_vector_client.index_create( + session_admin_client.index_create( namespace="test", name="demo3", vector_field="unit_test", @@ -262,6 +268,7 @@ def test_vector_search_with_set_different_than_name( truth_numpy, query_numpy, session_vector_client, + session_admin_client, name="demo3", ) @@ 
-271,13 +278,14 @@ def test_vector_search_with_index_storage_different_than_name( truth_numpy, query_numpy, session_vector_client, + session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_vector_client.index_create( + session_admin_client.index_create( namespace="test", name="demo4", vector_field="unit_test", @@ -300,6 +308,7 @@ def test_vector_search_with_index_storage_different_than_name( truth_numpy, query_numpy, session_vector_client, + session_admin_client, name="demo4", ) @@ -309,13 +318,14 @@ def test_vector_search_with_index_storage_different_location( truth_numpy, query_numpy, session_vector_client, + session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_vector_client.index_create( + session_admin_client.index_create( namespace="test", name="demo5", vector_field="unit_test", @@ -338,6 +348,7 @@ def test_vector_search_with_index_storage_different_location( truth_numpy, query_numpy, session_vector_client, + session_admin_client, name="demo5", ) @@ -347,13 +358,14 @@ def test_vector_search_with_separate_namespace( truth_numpy, query_numpy, session_vector_client, + session_admin_client, extensive_vector_search, ): if not extensive_vector_search: pytest.skip("Extensive vector tests disabled") - session_vector_client.index_create( + session_admin_client.index_create( namespace="test", name="demo6", vector_field="unit_test", @@ -376,12 +388,13 @@ def test_vector_search_with_separate_namespace( truth_numpy, query_numpy, session_vector_client, + session_admin_client, name="demo6", ) def test_vector_vector_search_timeout( - session_vector_client, with_latency + session_vector_client, session_admin_client, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") diff --git a/tests/standard/test_vector_client_is_indexed.py 
b/tests/standard/test_vector_client_is_indexed.py index 5284fd6e..b2434bda 100644 --- a/tests/standard/test_vector_client_is_indexed.py +++ b/tests/standard/test_vector_client_is_indexed.py @@ -8,13 +8,14 @@ def test_vector_is_indexed( + session_admin_client, session_vector_client, index, record, ): # wait for the record to be indexed wait_for_index( - admin_client=session_vector_client, + admin_client=session_admin_client, namespace=DEFAULT_NAMESPACE, index=index ) diff --git a/tests/standard/test_vector_client_search_by_key.py b/tests/standard/test_vector_client_search_by_key.py index d0141023..5ea3df98 100644 --- a/tests/standard/test_vector_client_search_by_key.py +++ b/tests/standard/test_vector_client_search_by_key.py @@ -265,10 +265,11 @@ def __init__( ) def test_vector_search_by_key( session_vector_client, + session_admin_client, test_case, ): - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.search_namespace, name=test_case.index_name, vector_field=test_case.vector_field, @@ -297,7 +298,7 @@ def test_vector_search_by_key( ) wait_for_index( - admin_client=session_vector_client, + admin_client=session_admin_client, namespace=test_case.search_namespace, index=test_case.index_name, ) @@ -323,7 +324,7 @@ def test_vector_search_by_key( key=key, ) - session_vector_client.index_drop( + session_admin_client.index_drop( namespace=test_case.search_namespace, name=test_case.index_name, ) @@ -331,9 +332,10 @@ def test_vector_search_by_key( def test_vector_search_by_key_different_namespaces( session_vector_client, + session_admin_client, ): - session_vector_client.index_create( + session_admin_client.index_create( namespace="index_storage", name="diff_ns_idx", vector_field="vec", @@ -372,7 +374,7 @@ def test_vector_search_by_key_different_namespaces( ) wait_for_index( - admin_client=session_vector_client, + admin_client=session_admin_client, namespace="index_storage", index="diff_ns_idx", ) @@ -413,7 +415,7 @@ def 
test_vector_search_by_key_different_namespaces( key="search_for", ) - session_vector_client.index_drop( + session_admin_client.index_drop( namespace="index_storage", name="diff_ns_idx", ) \ No newline at end of file diff --git a/tests/standard/test_vector_search.py b/tests/standard/test_vector_search.py index 82c6326f..a18414b4 100644 --- a/tests/standard/test_vector_search.py +++ b/tests/standard/test_vector_search.py @@ -99,10 +99,11 @@ def __init__( ) def test_vector_search( session_vector_client, + session_admin_client, test_case, ): - session_vector_client.index_create( + session_admin_client.index_create( namespace=test_case.namespace, name=test_case.index_name, vector_field=test_case.vector_field, @@ -131,7 +132,7 @@ def test_vector_search( ) wait_for_index( - admin_client=session_vector_client, + admin_client=session_admin_client, namespace=test_case.namespace, index=test_case.index_name, ) @@ -153,7 +154,7 @@ def test_vector_search( key=key, ) - session_vector_client.index_drop( + session_admin_client.index_drop( namespace=test_case.namespace, name=test_case.index_name, ) From 386e931e12b66fcac41f459944ee959260841694 Mon Sep 17 00:00:00 2001 From: dwelch-spike <53876192+dwelch-spike@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:46:36 -0800 Subject: [PATCH 20/21] refactor:! 
merge async admin and vector client (#71) --- src/aerospike_vector_search/aio/__init__.py | 1 - src/aerospike_vector_search/aio/admin.py | 793 ------------------ src/aerospike_vector_search/aio/client.py | 694 ++++++++++++++- src/aerospike_vector_search/client.py | 2 +- tests/rbac/aio/conftest.py | 7 +- tests/rbac/aio/test_admin_client_add_user.py | 12 +- tests/rbac/aio/test_admin_client_drop_user.py | 8 +- tests/rbac/aio/test_admin_client_get_user.py | 6 +- .../rbac/aio/test_admin_client_grant_roles.py | 8 +- .../rbac/aio/test_admin_client_list_roles.py | 6 +- .../rbac/aio/test_admin_client_list_users.py | 6 +- .../aio/test_admin_client_revoke_roles.py | 8 +- .../test_admin_client_update_credentials.py | 8 +- 13 files changed, 726 insertions(+), 833 deletions(-) delete mode 100644 src/aerospike_vector_search/aio/admin.py diff --git a/src/aerospike_vector_search/aio/__init__.py b/src/aerospike_vector_search/aio/__init__.py index eb50a5a8..caad5d72 100644 --- a/src/aerospike_vector_search/aio/__init__.py +++ b/src/aerospike_vector_search/aio/__init__.py @@ -1,5 +1,4 @@ from .client import Client -from .admin import Client as AdminClient from ..types import ( HostPort, Key, diff --git a/src/aerospike_vector_search/aio/admin.py b/src/aerospike_vector_search/aio/admin.py deleted file mode 100644 index aa57bcd1..00000000 --- a/src/aerospike_vector_search/aio/admin.py +++ /dev/null @@ -1,793 +0,0 @@ -import asyncio -import logging -import sys -from typing import Optional, Union - -import grpc - -from .internal import channel_provider -from .. 
import types -from ..shared.admin_helpers import BaseClient -from ..shared.conversions import fromIndexStatusResponse -from ..types import Role, IndexDefinition - -logger = logging.getLogger(__name__) - - -class Client(BaseClient): - """ - Aerospike Vector Search Asyncio Admin Client - - This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. - - :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster. - :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] - - :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. - :type listener_name: Optional[str] - - :param is_loadbalancer: If true, the first seed address will be treated as a load balancer node. Defaults to False. - :type is_loadbalancer: Optional[bool] - - :param service_config_path: Path to the service configuration file. Defaults to None. - :type service_config_path: Optional[str] - - :param username: Username for Role-Based Access. Defaults to None. - :type username: Optional[str] - - :param password: Password for Role-Based Access. Defaults to None. - :type password: Optional[str] - - :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. - :type root_certificate: Optional[list[bytes], bytes] - - :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. - :type certificate_chain: Optional[bytes] - - :param private_key: The PEM-encoded private key as a byte string. Defaults to None. - :type private_key: Optional[bytes] - - :raises AVSClientError: Raised when no seed host is provided. 
- - """ - - def __init__( - self, - *, - seeds: Union[types.HostPort, tuple[types.HostPort, ...]], - listener_name: Optional[str] = None, - is_loadbalancer: Optional[bool] = False, - service_config_path: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - root_certificate: Optional[Union[list[str], str]] = None, - certificate_chain: Optional[str] = None, - private_key: Optional[str] = None, - ssl_target_name_override: Optional[str] = None, - ) -> None: - seeds = self._prepare_seeds(seeds) - - self._channel_provider = channel_provider.ChannelProvider( - seeds, - listener_name, - is_loadbalancer, - username, - password, - root_certificate, - certificate_chain, - private_key, - service_config_path, - ssl_target_name_override, - ) - - async def index_create( - self, - *, - namespace: str, - name: str, - vector_field: str, - dimensions: int, - vector_distance_metric: types.VectorDistanceMetric = ( - types.VectorDistanceMetric.SQUARED_EUCLIDEAN - ), - sets: Optional[str] = None, - index_params: Optional[types.HnswParams] = None, - index_labels: Optional[dict[str, str]] = None, - index_storage: Optional[types.IndexStorage] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Create an index. - - :param namespace: The namespace for the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param vector_field: The name of the field containing vector data. - :type vector_field: str - - :param dimensions: The number of dimensions in the vector data. - :type dimensions: int - - :param vector_distance_metric: - The distance metric used to compare when performing a vector search. - Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type vector_distance_metric: types.VectorDistanceMetric - - :param sets: The set used for the index. Defaults to None. 
- :type sets: Optional[str] - - :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning - vector search. Defaults to None. If index_params is None, then the default values - specified for :class:`types.HnswParams` will be used. - :type index_params: Optional[types.HnswParams] - - :param index_labels: Metadata associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param index_storage: Namespace and set where index overhead (non-vector data) is stored. - :type index_storage: Optional[types.IndexStorage] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method creates an index with the specified parameters and waits for the index creation to complete. - It waits for up to 100,000 seconds or the specified timeout for the index creation to complete. 
- """ - - await self._channel_provider._is_ready() - - (index_stub, index_create_request, kwargs) = self._prepare_index_create( - namespace, - name, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_labels, - index_storage, - timeout, - logger, - ) - - try: - await index_stub.Create( - index_create_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to create index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - await self._wait_for_index_creation( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - async def index_update( - self, - *, - namespace: str, - name: str, - index_labels: Optional[dict[str, str]] = None, - hnsw_update_params: Optional[types.HnswIndexUpdate] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Update an existing index. - - :param namespace: The namespace for the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param index_labels: Optional labels associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param hnsw_update_params: Parameters for updating HNSW index settings. - :type hnsw_update_params: Optional[types.HnswIndexUpdate] - - :param timeout: Timeout in seconds for internal index update tasks. Defaults to 100_000. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. 
- """ - - await self._channel_provider._is_ready() - - (index_stub, index_update_request, kwargs) = self._prepare_index_update( - namespace = namespace, - name = name, - index_labels = index_labels, - hnsw_update_params = hnsw_update_params, - logger = logger, - timeout = timeout - ) - - try: - await index_stub.Update( - index_update_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - async def index_drop( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> None: - """ - Drop an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method drops an index with the specified parameters and waits for the index deletion to complete. - It waits for up to 100,000 seconds for the index deletion to complete. 
- """ - await self._channel_provider._is_ready() - - (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( - namespace, name, timeout, logger - ) - - try: - await index_stub.Drop( - index_drop_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - await self._wait_for_index_deletion( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for deletion with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - async def index_list( - self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True - ) -> list[IndexDefinition]: - """ - List all indices. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: list[dict]: A list of indices. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- """ - await self._channel_provider._is_ready() - - (index_stub, index_list_request, kwargs) = self._prepare_index_list( - timeout, - logger, - apply_defaults, - ) - - try: - response = await index_stub.List( - index_list_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list indexes with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_list(response) - - async def index_get( - self, - *, - namespace: str, - name: str, - timeout: Optional[int] = None, - apply_defaults: Optional[bool] = True, - ) -> IndexDefinition: - """ - Retrieve the information related with an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: dict[str, Union[int, str]: Information about an index. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - await self._channel_provider._is_ready() - - (index_stub, index_get_request, kwargs) = self._prepare_index_get( - namespace, - name, - timeout, - logger, - apply_defaults, - ) - - try: - response = await index_stub.Get( - index_get_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_get(response) - - async def index_get_status( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> types.IndexStatusResponse: - """ - Retrieve the number of records queued to be merged into an index. - - :param namespace: The namespace of the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Returns: IndexStatusResponse: object containing index status information. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, - the records may not immediately begin to merge into the index. - - Warning: This API is subject to change. 
- """ - await self._channel_provider._is_ready() - - (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( - namespace, name, timeout, logger - ) - - try: - response = await index_stub.GetStatus( - index_get_status_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - return fromIndexStatusResponse(response) - except grpc.RpcError as e: - logger.error("Failed to get index status with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - - async def add_user( - self, - *, - username: str, - password: str, - roles: list[str], - timeout: Optional[int] = None, - ) -> None: - """ - Add role-based access AVS User to the AVS Server. - - :param username: Username for the new user. - :type username: str - - :param password: Password for the new user. - :type password: str - - :param roles: Roles for the new user. - :type password: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - await self._channel_provider._is_ready() - - (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( - username, password, roles, timeout, logger - ) - - try: - await user_admin_stub.AddUser( - add_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to add user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - async def update_credentials( - self, *, username: str, password: str, timeout: Optional[int] = None - ) -> None: - """ - Update AVS User credentials. - - :param username: Username of the user to update. 
- :type username: str - - :param password: New password for the user. - :type password: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - await self._channel_provider._is_ready() - - (user_admin_stub, update_credentials_request, kwargs) = ( - self._prepare_update_credentials(username, password, timeout, logger) - ) - - try: - await user_admin_stub.UpdateCredentials( - update_credentials_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update credentials with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - async def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: - """ - Drops AVS User from the AVS Server. - - :param username: Username of the user to drop. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - await self._channel_provider._is_ready() - - (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( - username, timeout, logger - ) - - try: - await user_admin_stub.DropUser( - drop_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - async def get_user( - self, *, username: str, timeout: Optional[int] = None - ) -> types.User: - """ - Retrieves AVS User information from the AVS Server. - - :param username: Username of the user to be retrieved. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - return: types.User: AVS User - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - await self._channel_provider._is_ready() - - (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( - username, timeout, logger - ) - - try: - response = await user_admin_stub.GetUser( - get_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - return self._respond_get_user(response) - - async def list_users(self, timeout: Optional[int] = None) -> list[types.User]: - """ - List all users existing on the AVS Server. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type timeout: int - - return: list[types.User]: list of AVS Users - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - await self._channel_provider._is_ready() - - (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( - timeout, logger - ) - - try: - response = await user_admin_stub.ListUsers( - list_users_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_users(response) - - async def grant_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) : - """ - Grant roles to existing AVS Users. - - :param username: Username of the user which will receive the roles. - :type username: str - - :param roles: Roles the specified user will receive. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - await self._channel_provider._is_ready() - - (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( - username, roles, timeout, logger - ) - - try: - await user_admin_stub.GrantRoles( - grant_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to grant roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - async def revoke_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) : - """ - Revoke roles from existing AVS Users. - - :param username: Username of the user undergoing role removal. - :type username: str - - :param roles: Roles to be revoked. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - await self._channel_provider._is_ready() - - (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( - username, roles, timeout, logger - ) - - try: - await user_admin_stub.RevokeRoles( - revoke_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to revoke roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - async def list_roles(self, timeout: Optional[int] = None) -> list[Role]: - """ - list roles of existing AVS Users. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - returns: list[str]: Roles available in the AVS Server. 
- - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - """ - await self._channel_provider._is_ready() - - (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( - timeout, logger - ) - - try: - response = await user_admin_stub.ListRoles( - list_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_roles(response) - - async def _wait_for_index_creation( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be created. - """ - await self._channel_provider._is_ready() - - (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - while True: - self._check_timeout(start_time, timeout) - try: - await index_stub.GetStatus( - index_creation_request, - credentials=self._channel_provider.get_token(), - ) - logger.debug("Index created successfully") - # Index has been created - return - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - - # Wait for some more time. - await asyncio.sleep(wait_interval) - else: - logger.error("Failed waiting for index creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - async def _wait_for_index_deletion( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be deleted. 
- """ - await self._channel_provider._is_ready() - - # Wait interval between polling - (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - - while True: - self._check_timeout(start_time, timeout) - - try: - await index_stub.GetStatus( - index_deletion_request, - credentials=self._channel_provider.get_token(), - ) - # Wait for some more time. - await asyncio.sleep(wait_interval) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - logger.debug("Index deleted successfully") - # Index has been created - return - else: - logger.error("Failed waiting for index deletion with error: %s", e) - - raise types.AVSServerError(rpc_error=e) - - async def close(self): - """ - Close the Aerospike Vector Search Admin Client. - - This method closes gRPC channels connected to Aerospike Vector Search. - - Note: - This method should be called when the VectorDbAdminClient is no longer needed to release resources. - """ - await self._channel_provider.close() - - async def __aenter__(self): - """ - Enter an asynchronous context manager for the admin client. - - Returns: - VectorDbAdminlient: Aerospike Vector Search Admin Client instance. - """ - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """ - Exit an asynchronous context manager for the admin client. - """ - await self.close() diff --git a/src/aerospike_vector_search/aio/client.py b/src/aerospike_vector_search/aio/client.py index 79b92bf3..5aa12caa 100644 --- a/src/aerospike_vector_search/aio/client.py +++ b/src/aerospike_vector_search/aio/client.py @@ -9,12 +9,14 @@ from .. 
import types from .internal import channel_provider -from ..shared.client_helpers import BaseClient +from ..shared.client_helpers import BaseClient as BaseClientMixin +from ..shared.admin_helpers import BaseClient as AdminBaseClientMixin +from ..shared.conversions import fromIndexStatusResponse logger = logging.getLogger(__name__) -class Client(BaseClient): +class Client(BaseClientMixin, AdminBaseClientMixin): """ Aerospike Vector Search Asyncio Admin Client @@ -812,6 +814,692 @@ async def wait_for_index_completion( validation_count = 0 await asyncio.sleep(wait_interval_float) + async def index_create( + self, + *, + namespace: str, + name: str, + vector_field: str, + dimensions: int, + vector_distance_metric: types.VectorDistanceMetric = ( + types.VectorDistanceMetric.SQUARED_EUCLIDEAN + ), + sets: Optional[str] = None, + index_params: Optional[types.HnswParams] = None, + index_labels: Optional[dict[str, str]] = None, + index_storage: Optional[types.IndexStorage] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Create an index. + + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param vector_field: The name of the field containing vector data. + :type vector_field: str + + :param dimensions: The number of dimensions in the vector data. + :type dimensions: int + + :param vector_distance_metric: + The distance metric used to compare when performing a vector search. + Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. + :type vector_distance_metric: types.VectorDistanceMetric + + :param sets: The set used for the index. Defaults to None. + :type sets: Optional[str] + + :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning + vector search. Defaults to None. If index_params is None, then the default values + specified for :class:`types.HnswParams` will be used. 
+ :type index_params: Optional[types.HnswParams] + + :param index_labels: Metadata associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param index_storage: Namespace and set where index overhead (non-vector data) is stored. + :type index_storage: Optional[types.IndexStorage] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to 100_000. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method creates an index with the specified parameters and waits for the index creation to complete. + It waits for up to 100,000 seconds or the specified timeout for the index creation to complete. + """ + + await self._channel_provider._is_ready() + + (index_stub, index_create_request, kwargs) = self._prepare_index_create( + namespace, + name, + vector_field, + dimensions, + vector_distance_metric, + sets, + index_params, + index_labels, + index_storage, + timeout, + logger, + ) + + try: + await index_stub.Create( + index_create_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to create index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + await self._wait_for_index_creation( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + async def index_update( + self, + *, + namespace: str, + name: str, + index_labels: Optional[dict[str, str]] = None, + hnsw_update_params: Optional[types.HnswIndexUpdate] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Update an existing index. 
+ + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param index_labels: Optional labels associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param hnsw_update_params: Parameters for updating HNSW index settings. + :type hnsw_update_params: Optional[types.HnswIndexUpdate] + + :param timeout: Timeout in seconds for internal index update tasks. Defaults to 100_000. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. + """ + + await self._channel_provider._is_ready() + + (index_stub, index_update_request, kwargs) = self._prepare_index_update( + namespace = namespace, + name = name, + index_labels = index_labels, + hnsw_update_params = hnsw_update_params, + logger = logger, + timeout = timeout + ) + + try: + await index_stub.Update( + index_update_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + async def index_drop( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> None: + """ + Drop an index. + + :param namespace: The namespace of the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method drops an index with the specified parameters and waits for the index deletion to complete. 
+ It waits for up to 100,000 seconds for the index deletion to complete. + """ + await self._channel_provider._is_ready() + + (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( + namespace, name, timeout, logger + ) + + try: + await index_stub.Drop( + index_drop_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + await self._wait_for_index_deletion( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for deletion with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def index_list( + self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True + ) -> list[types.IndexDefinition]: + """ + List all indices. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: list[types.IndexDefinition]: A list of index definitions. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the indexes. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ """ + await self._channel_provider._is_ready() + + (index_stub, index_list_request, kwargs) = self._prepare_index_list( + timeout, + logger, + apply_defaults, + ) + + try: + response = await index_stub.List( + index_list_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list indexes with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_list(response) + + async def index_get( + self, + *, + namespace: str, + name: str, + timeout: Optional[int] = None, + apply_defaults: Optional[bool] = True, + ) -> types.IndexDefinition: + """ + Retrieve the information related to an index. + + :param namespace: The namespace of the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: types.IndexDefinition: Information about an index. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters.
+ + """ + await self._channel_provider._is_ready() + + (index_stub, index_get_request, kwargs) = self._prepare_index_get( + namespace, + name, + timeout, + logger, + apply_defaults, + ) + + try: + response = await index_stub.Get( + index_get_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_get(response) + + async def index_get_status( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> types.IndexStatusResponse: + """ + Retrieve the number of records queued to be merged into an index. + + :param namespace: The namespace of the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Returns: IndexStatusResponse: object containing index status information. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method retrieves the status of the specified index. If index_get_status is called right after the vector client puts some records into Aerospike Vector Search, + the records may not immediately begin to merge into the index. + + Warning: This API is subject to change.
+ """ + await self._channel_provider._is_ready() + + (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( + namespace, name, timeout, logger + ) + + try: + response = await index_stub.GetStatus( + index_get_status_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + return fromIndexStatusResponse(response) + except grpc.RpcError as e: + logger.error("Failed to get index status with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + + async def add_user( + self, + *, + username: str, + password: str, + roles: list[str], + timeout: Optional[int] = None, + ) -> None: + """ + Add role-based access AVS User to the AVS Server. + + :param username: Username for the new user. + :type username: str + + :param password: Password for the new user. + :type password: str + + :param roles: Roles for the new user. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + await self._channel_provider._is_ready() + + (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( + username, password, roles, timeout, logger + ) + + try: + await user_admin_stub.AddUser( + add_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to add user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def update_credentials( + self, *, username: str, password: str, timeout: Optional[int] = None + ) -> None: + """ + Update AVS User credentials. + + :param username: Username of the user to update.
+ :type username: str + + :param password: New password for the user. + :type password: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + await self._channel_provider._is_ready() + + (user_admin_stub, update_credentials_request, kwargs) = ( + self._prepare_update_credentials(username, password, timeout, logger) + ) + + try: + await user_admin_stub.UpdateCredentials( + update_credentials_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update credentials with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: + """ + Drops AVS User from the AVS Server. + + :param username: Username of the user to drop. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + await self._channel_provider._is_ready() + + (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( + username, timeout, logger + ) + + try: + await user_admin_stub.DropUser( + drop_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def get_user( + self, *, username: str, timeout: Optional[int] = None + ) -> types.User: + """ + Retrieves AVS User information from the AVS Server. + + :param username: Username of the user to be retrieved. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + return: types.User: AVS User + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + await self._channel_provider._is_ready() + + (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( + username, timeout, logger + ) + + try: + response = await user_admin_stub.GetUser( + get_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + return self._respond_get_user(response) + + async def list_users(self, timeout: Optional[int] = None) -> list[types.User]: + """ + List all users existing on the AVS Server. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
+ :type timeout: int + + return: list[types.User]: list of AVS Users + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + await self._channel_provider._is_ready() + + (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( + timeout, logger + ) + + try: + response = await user_admin_stub.ListUsers( + list_users_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_users(response) + + async def grant_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) : + """ + Grant roles to existing AVS Users. + + :param username: Username of the user which will receive the roles. + :type username: str + + :param roles: Roles the specified user will receive. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + await self._channel_provider._is_ready() + + (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( + username, roles, timeout, logger + ) + + try: + await user_admin_stub.GrantRoles( + grant_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to grant roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def revoke_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) : + """ + Revoke roles from existing AVS Users. + + :param username: Username of the user undergoing role removal. + :type username: str + + :param roles: Roles to be revoked. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + await self._channel_provider._is_ready() + + (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( + username, roles, timeout, logger + ) + + try: + await user_admin_stub.RevokeRoles( + revoke_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to revoke roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def list_roles(self, timeout: Optional[int] = None) -> list[types.Role]: + """ + List roles of existing AVS Users. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + returns: list[types.Role]: Roles available in the AVS Server.
+ + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + await self._channel_provider._is_ready() + + (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( + timeout, logger + ) + + try: + response = await user_admin_stub.ListRoles( + list_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_roles(response) + + async def _wait_for_index_creation( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be created. + """ + await self._channel_provider._is_ready() + + (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + while True: + self._check_timeout(start_time, timeout) + try: + await index_stub.GetStatus( + index_creation_request, + credentials=self._channel_provider.get_token(), + ) + logger.debug("Index created successfully") + # Index has been created + return + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + + # Wait for some more time. + await asyncio.sleep(wait_interval) + else: + logger.error("Failed waiting for index creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def _wait_for_index_deletion( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be deleted. 
+ """ + await self._channel_provider._is_ready() + + # Wait interval between polling + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + + while True: + self._check_timeout(start_time, timeout) + + try: + await index_stub.GetStatus( + index_deletion_request, + credentials=self._channel_provider.get_token(), + ) + # Wait for some more time. + await asyncio.sleep(wait_interval) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + logger.debug("Index deleted successfully") + # Index has been deleted + return + else: + logger.error("Failed waiting for index deletion with error: %s", e) + + raise types.AVSServerError(rpc_error=e) + async def close(self): """ Close the Aerospike Vector Search Client. @@ -819,7 +1507,7 @@ async def close(self): This method closes gRPC channels connected to Aerospike Vector Search. Note: - This method should be called when the VectorDbAdminClient is no longer needed to release resources. + This method should be called when the client is no longer needed to release resources. """ await self._channel_provider.close() diff --git a/src/aerospike_vector_search/client.py b/src/aerospike_vector_search/client.py index b78b26a1..78ead024 100644 --- a/src/aerospike_vector_search/client.py +++ b/src/aerospike_vector_search/client.py @@ -799,7 +799,7 @@ def close(self): This method closes gRPC channels connected to Aerospike Vector Search. Note: - This method should be called when the VectorDbAdminClient is no longer needed to release resources. + This method should be called when the client is no longer needed to release resources.
""" self._channel_provider.close() diff --git a/tests/rbac/aio/conftest.py b/tests/rbac/aio/conftest.py index 0c881c08..3db4ba34 100644 --- a/tests/rbac/aio/conftest.py +++ b/tests/rbac/aio/conftest.py @@ -1,7 +1,6 @@ import pytest import asyncio from aerospike_vector_search.aio import Client -from aerospike_vector_search.aio.admin import Client as AdminClient from aerospike_vector_search import types @@ -28,7 +27,7 @@ async def drop_all_indexes( with open(private_key, "rb") as f: private_key = f.read() - async with AdminClient( + async with Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -48,7 +47,7 @@ async def drop_all_indexes( @pytest.fixture(scope="module") -async def session_rbac_admin_client( +async def session_rbac_client( username, password, root_certificate, @@ -70,7 +69,7 @@ async def session_rbac_admin_client( with open(private_key, "rb") as f: private_key = f.read() - client = AdminClient( + client = Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, diff --git a/tests/rbac/aio/test_admin_client_add_user.py b/tests/rbac/aio/test_admin_client_add_user.py index f7bf2d6d..3fd549f3 100644 --- a/tests/rbac/aio/test_admin_client_add_user.py +++ b/tests/rbac/aio/test_admin_client_add_user.py @@ -25,12 +25,12 @@ def __init__( ), ], ) -async def test_add_user(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_add_user(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = await session_rbac_admin_client.get_user(username=test_case.username) + result = await session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username @@ -57,12 +57,12 @@ async def test_add_user(session_rbac_admin_client, test_case): ), ], ) -async def 
test_add_user_with_roles(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_add_user_with_roles(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = await session_rbac_admin_client.get_user(username=test_case.username) + result = await session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/aio/test_admin_client_drop_user.py b/tests/rbac/aio/test_admin_client_drop_user.py index 449e9e96..5dc91975 100644 --- a/tests/rbac/aio/test_admin_client_drop_user.py +++ b/tests/rbac/aio/test_admin_client_drop_user.py @@ -24,13 +24,13 @@ def __init__( ), ], ) -async def test_drop_user(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_drop_user(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - await session_rbac_admin_client.drop_user( + await session_rbac_client.drop_user( username=test_case.username, ) with pytest.raises(AVSServerError) as e_info: - result = await session_rbac_admin_client.get_user(username=test_case.username) + result = await session_rbac_client.get_user(username=test_case.username) assert e_info.value.rpc_error.code() == grpc.StatusCode.NOT_FOUND diff --git a/tests/rbac/aio/test_admin_client_get_user.py b/tests/rbac/aio/test_admin_client_get_user.py index c253a0bb..d37d259c 100644 --- a/tests/rbac/aio/test_admin_client_get_user.py +++ b/tests/rbac/aio/test_admin_client_get_user.py @@ -22,12 +22,12 @@ def __init__( ), ], ) -async def test_get_user(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_get_user(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) 
- result = await session_rbac_admin_client.get_user(username=test_case.username) + result = await session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/aio/test_admin_client_grant_roles.py b/tests/rbac/aio/test_admin_client_grant_roles.py index cf03c0ea..0a48bb99 100644 --- a/tests/rbac/aio/test_admin_client_grant_roles.py +++ b/tests/rbac/aio/test_admin_client_grant_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, granted_roles): ), ], ) -async def test_grant_roles(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_grant_roles(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - await session_rbac_admin_client.grant_roles( + await session_rbac_client.grant_roles( username=test_case.username, roles=test_case.granted_roles ) - result = await session_rbac_admin_client.get_user(username=test_case.username) + result = await session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/aio/test_admin_client_list_roles.py b/tests/rbac/aio/test_admin_client_list_roles.py index 4179c75e..233363ed 100644 --- a/tests/rbac/aio/test_admin_client_list_roles.py +++ b/tests/rbac/aio/test_admin_client_list_roles.py @@ -25,11 +25,11 @@ def __init__( ), ], ) -async def test_list_roles(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_list_roles(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = await session_rbac_admin_client.list_roles() + result = await session_rbac_client.list_roles() for role in result: assert role.id in test_case.roles diff --git a/tests/rbac/aio/test_admin_client_list_users.py 
b/tests/rbac/aio/test_admin_client_list_users.py index 02ac6e8c..7067bf5a 100644 --- a/tests/rbac/aio/test_admin_client_list_users.py +++ b/tests/rbac/aio/test_admin_client_list_users.py @@ -17,12 +17,12 @@ def __init__(self, *, username, password): ), ], ) -async def test_list_users(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_list_users(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - result = await session_rbac_admin_client.list_users() + result = await session_rbac_client.list_users() user_found = False for user in result: if user.username == test_case.username: diff --git a/tests/rbac/aio/test_admin_client_revoke_roles.py b/tests/rbac/aio/test_admin_client_revoke_roles.py index 7ca68a41..e6c8fcaf 100644 --- a/tests/rbac/aio/test_admin_client_revoke_roles.py +++ b/tests/rbac/aio/test_admin_client_revoke_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, revoked_roles): ), ], ) -async def test_revoke_roles(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_revoke_roles(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - await session_rbac_admin_client.revoke_roles( + await session_rbac_client.revoke_roles( username=test_case.username, roles=test_case.roles ) - result = await session_rbac_admin_client.get_user(username=test_case.username) + result = await session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/aio/test_admin_client_update_credentials.py b/tests/rbac/aio/test_admin_client_update_credentials.py index edffbcd6..91b0b2df 100644 --- a/tests/rbac/aio/test_admin_client_update_credentials.py +++ b/tests/rbac/aio/test_admin_client_update_credentials.py @@ -19,17 
+19,17 @@ def __init__(self, *, username, old_password, new_password): ), ], ) -async def test_update_credentials(session_rbac_admin_client, test_case): - await session_rbac_admin_client.add_user( +async def test_update_credentials(session_rbac_client, test_case): + await session_rbac_client.add_user( username=test_case.username, password=test_case.old_password, roles=None ) - await session_rbac_admin_client.update_credentials( + await session_rbac_client.update_credentials( username=test_case.username, password=test_case.new_password, ) - result = await session_rbac_admin_client.get_user(username=test_case.username) + result = await session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username From 813d4b54b830630609709d55c9b5311bab18bcdc Mon Sep 17 00:00:00 2001 From: dylan Date: Wed, 8 Jan 2025 12:41:26 -0800 Subject: [PATCH 21/21] redo sync admin and standard client merge --- .../extensive_vector_search_tests.yml | 81 -- .github/workflows/integration_test.yml | 2 +- docs/admin.rst | 10 - docs/aio.rst | 1 - docs/aio_admin.rst | 10 - docs/index.rst | 11 +- docs/sync.rst | 1 - pyproject.toml | 2 +- src/aerospike_vector_search/__init__.py | 1 - src/aerospike_vector_search/admin.py | 756 ------------------ src/aerospike_vector_search/client.py | 655 ++++++++++++++- tests/rbac/sync/conftest.py | 7 +- tests/rbac/sync/test_admin_client_add_user.py | 12 +- .../rbac/sync/test_admin_client_drop_user.py | 8 +- tests/rbac/sync/test_admin_client_get_user.py | 6 +- .../sync/test_admin_client_grant_roles.py | 8 +- .../rbac/sync/test_admin_client_list_roles.py | 6 +- .../rbac/sync/test_admin_client_list_users.py | 6 +- .../sync/test_admin_client_revoke_roles.py | 8 +- .../test_admin_client_update_credentials.py | 8 +- tests/standard/conftest.py | 101 +-- .../test_admin_client_index_create.py | 74 +- .../standard/test_admin_client_index_drop.py | 10 +- tests/standard/test_admin_client_index_get.py | 12 +- 
.../test_admin_client_index_get_status.py | 10 +- .../standard/test_admin_client_index_list.py | 8 +- .../test_admin_client_index_update.py | 6 +- .../standard/test_extensive_vector_search.py | 416 ---------- .../standard/test_vector_client_is_indexed.py | 3 +- .../test_vector_client_search_by_key.py | 14 +- tests/standard/test_vector_search.py | 7 +- 31 files changed, 767 insertions(+), 1493 deletions(-) delete mode 100644 .github/workflows/extensive_vector_search_tests.yml delete mode 100644 docs/admin.rst delete mode 100644 docs/aio_admin.rst delete mode 100644 src/aerospike_vector_search/admin.py delete mode 100644 tests/standard/test_extensive_vector_search.py diff --git a/.github/workflows/extensive_vector_search_tests.yml b/.github/workflows/extensive_vector_search_tests.yml deleted file mode 100644 index 8c67ec47..00000000 --- a/.github/workflows/extensive_vector_search_tests.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: Run long running vector search tests - -on: - push: - branches: - - main - -jobs: - test-exhaustive-vector-search: - runs-on: ubuntu-24.04 - continue-on-error: false - - - strategy: - matrix: - python-version: ["3.12"] - - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python setup.py - pip install -r requirements.txt - working-directory: tests - - - - name: Retrieve the secret and decode it to a file - env: - FEATURE_FILE: ${{ secrets.FEATURE_FILE }} - run: | - echo $FEATURE_FILE | base64 --decode > features.conf - working-directory: tests - - - name: Docker Login - uses: docker/login-action@v2 - with: - registry: aerospike.jfrog.io - username: ${{ secrets.JFROG_USERNAME }} - password: ${{ secrets.JFROG_PASSWORD }} - - - - name: Set up RANDFILE environment variable - run: echo "RANDFILE=$HOME/.rnd" >> $GITHUB_ENV - - - name: Create .rnd 
file if it doesn't exist - run: touch $HOME/.rnd - - - name: create config - run: | - assets/call_gen.sh - cat /etc/hosts - working-directory: tests - - - name: Run unit tests - run: | - - docker run -d --name aerospike-vector-search --network=host -p 5000:5000 -v $(pwd):/etc/aerospike-vector-search aerospike/aerospike-vector-search:1.0.0 - docker run -d --name aerospike -p 3000:3000 -v .:/etc/aerospike aerospike/aerospike-server-enterprise:latest - - sleep 5 - - python -m pytest standard -s --host 0.0.0.0 --port 5000 --extensive_vector_search -vs - - mv .coverage coverage_data - working-directory: tests - - - name: Upload test coverage - uses: actions/upload-artifact@v4 - with: - name: coverage_exhaustive_vector_search - path: tests/coverage_data diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index a08cc2dd..8994d327 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -674,7 +674,7 @@ jobs: - name: Build and install aerospike-vector-search run: | python -m build - pip install dist/aerospike_vector_search-3.1.0-py3-none-any.whl + pip install dist/aerospike_vector_search-4.0.0-py3-none-any.whl - name: Upload to Codecov uses: codecov/codecov-action@v4 diff --git a/docs/admin.rst b/docs/admin.rst deleted file mode 100644 index eea13706..00000000 --- a/docs/admin.rst +++ /dev/null @@ -1,10 +0,0 @@ -AdminClient -===================== - -This class is the admin client, designed to conduct AVS administrative operation such as creating indexes, querying index information, and dropping indexes. - - -.. 
autoclass:: aerospike_vector_search.admin.Client - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/aio.rst b/docs/aio.rst index 0ad1b286..1936fdda 100644 --- a/docs/aio.rst +++ b/docs/aio.rst @@ -8,5 +8,4 @@ This module contains clients with coroutine methods used for asynchronous progra :maxdepth: 2 :caption: Contents: - aio_admin aio_client diff --git a/docs/aio_admin.rst b/docs/aio_admin.rst deleted file mode 100644 index 56141d97..00000000 --- a/docs/aio_admin.rst +++ /dev/null @@ -1,10 +0,0 @@ -AdminClient -===================== - -This class is the admin client, designed to conduct AVS administrative operation such as creating indexes, querying index information, and dropping indexes. - - -.. autoclass:: aerospike_vector_search.aio.admin.Client - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/index.rst b/docs/index.rst index 35246464..ad9f54e6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,14 +5,11 @@ Welcome to Aerospike Vector Search Client for Python. -This package splits the client functionality into two separate clients. - -This standard client (Client) specializes in performing database operations with vector data. -Moreover, the standard client supports Hierarchical Navigable Small World (HNSW) vector searches, +This client (Client) specializes in performing database operations with vector data. +Moreover, the client supports Hierarchical Navigable Small World (HNSW) vector searches, allowing users to find vectors similar to a given query vector within an index. - -This admin client (AdminClient) is designed to conduct AVS administrative operation such -as creating indexes, querying index information, and dropping indexes. +Administrative operations such as creating indexes, +querying index information, and dropping indexes are also supported. Please explore the modules below for more information on API usage and details.
diff --git a/docs/sync.rst b/docs/sync.rst index 53c461ea..06d46bcf 100644 --- a/docs/sync.rst +++ b/docs/sync.rst @@ -8,5 +8,4 @@ This module contains clients with methods used for synchronous programming. :maxdepth: 2 :caption: Contents: - admin client diff --git a/pyproject.toml b/pyproject.toml index e0ede0d2..15c61cc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Topic :: Database" ] -version = "3.1.0" +version = "4.0.0" requires-python = ">3.8" dependencies = [ "grpcio == 1.68.1", diff --git a/src/aerospike_vector_search/__init__.py b/src/aerospike_vector_search/__init__.py index cc942ecb..d08c2b3b 100644 --- a/src/aerospike_vector_search/__init__.py +++ b/src/aerospike_vector_search/__init__.py @@ -1,5 +1,4 @@ from .client import Client -from .admin import Client as AdminClient from .types import ( HostPort, Key, diff --git a/src/aerospike_vector_search/admin.py b/src/aerospike_vector_search/admin.py deleted file mode 100644 index a9638d09..00000000 --- a/src/aerospike_vector_search/admin.py +++ /dev/null @@ -1,756 +0,0 @@ -import logging -import sys -import time -from typing import Optional, Union - -import grpc - -from . import types -from .internal import channel_provider -from .shared.admin_helpers import BaseClient -from .shared.conversions import fromIndexStatusResponse -from .types import IndexDefinition, Role - -logger = logging.getLogger(__name__) - - -class Client(BaseClient): - """ - Aerospike Vector Search Admin Client - - This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. - - :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster. 
- :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] - - :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. - :type listener_name: Optional[str] - - :param is_loadbalancer: If true, the first seed address will be treated as a load balancer node. Defaults to False. - :type is_loadbalancer: Optional[bool] - - :param service_config_path: Path to the service configuration file. Defaults to None. - :type service_config_path: Optional[str] - - :param username: Username for Role-Based Access. Defaults to None. - :type username: Optional[str] - - :param password: Password for Role-Based Access. Defaults to None. - :type password: Optional[str] - - :param root_certificate: The PEM-encoded root certificates as a byte string. Defaults to None. - :type root_certificate: Optional[list[bytes], bytes] - - :param certificate_chain: The PEM-encoded certificate chain as a byte string. Defaults to None. - :type certificate_chain: Optional[bytes] - - :param private_key: The PEM-encoded private key as a byte string. Defaults to None. - :type private_key: Optional[bytes] - - :raises AVSClientError: Raised when no seed host is provided. 
- - """ - - def __init__( - self, - *, - seeds: Union[types.HostPort, tuple[types.HostPort, ...]], - listener_name: Optional[str] = None, - is_loadbalancer: Optional[bool] = False, - username: Optional[str] = None, - password: Optional[str] = None, - root_certificate: Optional[Union[list[str], str]] = None, - certificate_chain: Optional[str] = None, - private_key: Optional[str] = None, - service_config_path: Optional[str] = None, - ssl_target_name_override: Optional[str] = None, - ) -> None: - seeds = self._prepare_seeds(seeds) - - self._channel_provider = channel_provider.ChannelProvider( - seeds, - listener_name, - is_loadbalancer, - username, - password, - root_certificate, - certificate_chain, - private_key, - service_config_path, - ssl_target_name_override, - ) - - def index_create( - self, - *, - namespace: str, - name: str, - vector_field: str, - dimensions: int, - vector_distance_metric: types.VectorDistanceMetric = ( - types.VectorDistanceMetric.SQUARED_EUCLIDEAN - ), - sets: Optional[str] = None, - index_params: Optional[types.HnswParams] = None, - index_labels: Optional[dict[str, str]] = None, - index_storage: Optional[types.IndexStorage] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Create an index. - - :param namespace: The namespace for the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param vector_field: The name of the field containing vector data. - :type vector_field: str - - :param dimensions: The number of dimensions in the vector data. - :type dimensions: int - - :param vector_distance_metric: - The distance metric used to compare when performing a vector search. - Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type vector_distance_metric: types.VectorDistanceMetric - - :param sets: The set used for the index. Defaults to None. - :type sets: Optional[str] - - :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning - vector search. 
Defaults to None. If index_params is None, then the default values - specified for :class:`types.HnswParams` will be used. - :type index_params: Optional[types.HnswParams] - - :param index_labels: Metadata associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param index_storage: Namespace and set where index overhead (non-vector data) is stored. - :type index_storage: Optional[types.IndexStorage] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method creates an index with the specified parameters and waits for the index creation to complete. - It waits for up to 100,000 seconds for the index creation to complete. 
- """ - - (index_stub, index_create_request, kwargs) = self._prepare_index_create( - namespace, - name, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_labels, - index_storage, - timeout, - logger, - ) - - try: - index_stub.Create( - index_create_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to create index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - self._wait_for_index_creation( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def index_update( - self, - *, - namespace: str, - name: str, - index_labels: Optional[dict[str, str]] = None, - hnsw_update_params: Optional[types.HnswIndexUpdate] = None, - timeout: Optional[int] = 100_000, - ) -> None: - """ - Update an existing index. - - :param namespace: The namespace for the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param index_labels: Optional labels associated with the index. Defaults to None. - :type index_labels: Optional[dict[str, str]] - - :param hnsw_update_params: Parameters for updating HNSW index settings. - :type hnsw_update_params: Optional[types.HnswIndexUpdate] - - :param timeout: Time in seconds (default 100_000) this operation will wait before raising an error. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. 
- """ - (index_stub, index_update_request, kwargs) = self._prepare_index_update( - namespace = namespace, - name = name, - index_labels = index_labels, - hnsw_update_params = hnsw_update_params, - timeout = timeout, - logger = logger, - ) - - try: - index_stub.Update( - index_update_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - def index_drop( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> None: - """ - Drop an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - Note: - This method drops an index with the specified parameters and waits for the index deletion to complete. - It waits for up to 100,000 seconds for the index deletion to complete. 
- """ - - (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( - namespace, name, timeout, logger - ) - - try: - index_stub.Drop( - index_drop_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - try: - self._wait_for_index_deletion( - namespace=namespace, name=name, timeout=100_000 - ) - except grpc.RpcError as e: - logger.error("Failed waiting for deletion with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def index_list( - self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True - ) -> list[IndexDefinition]: - """ - List all indices. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: list[dict]: A list of indices. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - """ - - (index_stub, index_list_request, kwargs) = self._prepare_index_list( - timeout, logger, apply_defaults - ) - - try: - response = index_stub.List( - index_list_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list indexes with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_list(response) - - def index_get( - self, - *, - namespace: str, - name: str, - timeout: Optional[int] = None, - apply_defaults: Optional[bool] = True, - ) -> IndexDefinition: - """ - Retrieve the information related with an index. 
- - :param namespace: The namespace of the index. - :type namespace: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. - :type apply_defaults: bool - - Returns: dict[str, Union[int, str]: Information about an index. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - - (index_stub, index_get_request, kwargs) = self._prepare_index_get( - namespace, name, timeout, logger, apply_defaults - ) - - try: - response = index_stub.Get( - index_get_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get index with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_index_get(response) - - def index_get_status( - self, *, namespace: str, name: str, timeout: Optional[int] = None - ) -> types.IndexStatusResponse: - """ - Retrieve the number of records queued to be merged into an index. - - :param namespace: The namespace of the index. - :type name: str - - :param name: The name of the index. - :type name: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Returns: IndexStatusResponse: AVS response containing index status information. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - Note: - This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, - the records may not immediately begin to merge into the index. - - Warning: This API is subject to change. - """ - (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( - namespace, name, timeout, logger - ) - - try: - response = index_stub.GetStatus( - index_get_status_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - return fromIndexStatusResponse(response) - except grpc.RpcError as e: - logger.error("Failed to get index status with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - - - - def add_user( - self, - *, - username: str, - password: str, - roles: list[str], - timeout: Optional[int] = None, - ) -> None: - """ - Add role-based access AVS User to the AVS Server. - - :param username: Username for the new user. - :type username: str - - :param password: Password for the new user. - :type password: str - - :param roles: Roles for the new user. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( - username, password, roles, timeout, logger - ) - - try: - user_admin_stub.AddUser( - add_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to add user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def update_credentials( - self, *, username: str, password: str, timeout: Optional[int] = None - ) -> None: - """ - Update AVS User credentials. - - :param username: Username of the user to update. - :type username: str - - :param password: New password for the user. - :type password: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, update_credentials_request, kwargs) = ( - self._prepare_update_credentials(username, password, timeout, logger) - ) - - try: - user_admin_stub.UpdateCredentials( - update_credentials_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to update credentials with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: - """ - Drops AVS User from the AVS Server. - - :param username: Username of the user to drop. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type timeout: int - - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( - username, timeout, logger - ) - - try: - user_admin_stub.DropUser( - drop_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to drop user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def get_user(self, *, username: str, timeout: Optional[int] = None) -> types.User: - """ - Retrieves AVS User information from the AVS Server. - - :param username: Username of the user to be retrieved. - :type username: str - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - return: types.User: AVS User - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( - username, timeout, logger - ) - - try: - response = user_admin_stub.GetUser( - get_user_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to get user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - return self._respond_get_user(response) - - def list_users(self, timeout: Optional[int] = None) -> list[types.User]: - """ - List all users existing on the AVS Server. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
- :type timeout: int - - return: list[types.User]: list of AVS Users - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( - timeout, logger - ) - - try: - response = user_admin_stub.ListUsers( - list_users_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list user with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_users(response) - - def grant_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> None: - """ - Grant roles to existing AVS Users. - - :param username: Username of the user which will receive the roles. - :type username: str - - :param roles: Roles the specified user will receive. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
- - """ - (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( - username, roles, timeout, logger - ) - - try: - user_admin_stub.GrantRoles( - grant_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to grant roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def revoke_roles( - self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> None: - """ - Revoke roles from existing AVS Users. - - :param username: Username of the user undergoing role removal. - :type username: str - - :param roles: Roles to be revoked. - :type roles: list[str] - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. - This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - - """ - (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( - username, roles, timeout, logger - ) - - try: - user_admin_stub.RevokeRoles( - revoke_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to revoke roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def list_roles(self, timeout: Optional[int] = None) -> list[Role]: - """ - List roles available on the AVS server. - - :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. - :type timeout: int - - returns: list[str]: Roles available in the AVS Server. - - Raises: - AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. 
- This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. - """ - (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( - timeout, logger - ) - - try: - response = user_admin_stub.ListRoles( - list_roles_request, - credentials=self._channel_provider.get_token(), - **kwargs, - ) - except grpc.RpcError as e: - logger.error("Failed to list roles with error: %s", e) - raise types.AVSServerError(rpc_error=e) - return self._respond_list_roles(response) - - def _wait_for_index_creation( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be created. - """ - - (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - while True: - self._check_timeout(start_time, timeout) - try: - index_stub.GetStatus( - index_creation_request, - credentials=self._channel_provider.get_token(), - ) - logger.debug("Index created successfully") - # Index has been created - return - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - - # Wait for some more time. - time.sleep(wait_interval) - else: - logger.error("Failed waiting for index creation with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def _wait_for_index_deletion( - self, - *, - namespace: str, - name: str, - timeout: int = sys.maxsize, - wait_interval: float = 0.1, - ) -> None: - """ - Wait for the index to be deleted. - """ - - # Wait interval between polling - (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( - self._prepare_wait_for_index_waiting(namespace, name, wait_interval) - ) - - while True: - self._check_timeout(start_time, timeout) - - try: - index_stub.GetStatus( - index_deletion_request, - credentials=self._channel_provider.get_token(), - ) - # Wait for some more time. 
- time.sleep(wait_interval) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.NOT_FOUND: - logger.debug("Index deleted successfully") - # Index has been created - return - else: - logger.error("Failed waiting for index deletion with error: %s", e) - raise types.AVSServerError(rpc_error=e) - - def close(self): - """ - Close the Aerospike Vector Search Admin Client. - - This method closes gRPC channels connected to Aerospike Vector Search. - - Note: - This method should be called when the VectorDbAdminClient is no longer needed to release resources. - """ - self._channel_provider.close() - - def __enter__(self): - """ - Enter a context manager for the admin client. - - Returns: - VectorDbAdminClient: Aerospike Vector Search Admin Client instance. - """ - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Exit a context manager for the admin client. - """ - self.close() diff --git a/src/aerospike_vector_search/client.py b/src/aerospike_vector_search/client.py index 78ead024..21d0db8e 100644 --- a/src/aerospike_vector_search/client.py +++ b/src/aerospike_vector_search/client.py @@ -9,12 +9,14 @@ from . 
import types from .internal import channel_provider -from .shared.client_helpers import BaseClient +from .shared.client_helpers import BaseClient as BaseClientMixin +from .shared.admin_helpers import BaseClient as AdminBaseClientMixin +from .shared.conversions import fromIndexStatusResponse logger = logging.getLogger(__name__) -class Client(BaseClient): +class Client(BaseClientMixin, AdminBaseClientMixin): """ Aerospike Vector Search Client @@ -792,6 +794,655 @@ def wait_for_index_completion( validation_count = 0 time.sleep(wait_interval_float) + def index_create( + self, + *, + namespace: str, + name: str, + vector_field: str, + dimensions: int, + vector_distance_metric: types.VectorDistanceMetric = ( + types.VectorDistanceMetric.SQUARED_EUCLIDEAN + ), + sets: Optional[str] = None, + index_params: Optional[types.HnswParams] = None, + index_labels: Optional[dict[str, str]] = None, + index_storage: Optional[types.IndexStorage] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Create an index. + + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param vector_field: The name of the field containing vector data. + :type vector_field: str + + :param dimensions: The number of dimensions in the vector data. + :type dimensions: int + + :param vector_distance_metric: + The distance metric used to compare when performing a vector search. + Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. + :type vector_distance_metric: types.VectorDistanceMetric + + :param sets: The set used for the index. Defaults to None. + :type sets: Optional[str] + + :param index_params: (Optional[types.HnswParams], optional): Parameters used for tuning + vector search. Defaults to None. If index_params is None, then the default values + specified for :class:`types.HnswParams` will be used. 
+ :type index_params: Optional[types.HnswParams] + + :param index_labels: Metadata associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param index_storage: Namespace and set where index overhead (non-vector data) is stored. + :type index_storage: Optional[types.IndexStorage] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to create the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method creates an index with the specified parameters and waits for the index creation to complete. + It waits for up to 100,000 seconds for the index creation to complete. + """ + + (index_stub, index_create_request, kwargs) = self._prepare_index_create( + namespace, + name, + vector_field, + dimensions, + vector_distance_metric, + sets, + index_params, + index_labels, + index_storage, + timeout, + logger, + ) + + try: + index_stub.Create( + index_create_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to create index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + self._wait_for_index_creation( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def index_update( + self, + *, + namespace: str, + name: str, + index_labels: Optional[dict[str, str]] = None, + hnsw_update_params: Optional[types.HnswIndexUpdate] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Update an existing index. + + :param namespace: The namespace for the index. 
+ :type namespace: str + + :param name: The name of the index. + :type name: str + + :param index_labels: Optional labels associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param hnsw_update_params: Parameters for updating HNSW index settings. + :type hnsw_update_params: Optional[types.HnswIndexUpdate] + + :param timeout: Time in seconds (default 100_000) this operation will wait before raising an error. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. + """ + (index_stub, index_update_request, kwargs) = self._prepare_index_update( + namespace = namespace, + name = name, + index_labels = index_labels, + hnsw_update_params = hnsw_update_params, + timeout = timeout, + logger = logger, + ) + + try: + index_stub.Update( + index_update_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + def index_drop( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> None: + """ + Drop an index. + + :param namespace: The namespace of the index. + :type name: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + Note: + This method drops an index with the specified parameters and waits for the index deletion to complete. + It waits for up to 100,000 seconds for the index deletion to complete. 
+ """ + + (index_stub, index_drop_request, kwargs) = self._prepare_index_drop( + namespace, name, timeout, logger + ) + + try: + index_stub.Drop( + index_drop_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + try: + self._wait_for_index_deletion( + namespace=namespace, name=name, timeout=100_000 + ) + except grpc.RpcError as e: + logger.error("Failed waiting for deletion with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def index_list( + self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True + ) -> list[types.IndexDefinition]: + """ + List all indices. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: list[dict]: A list of indices. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + + (index_stub, index_list_request, kwargs) = self._prepare_index_list( + timeout, logger, apply_defaults + ) + + try: + response = index_stub.List( + index_list_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list indexes with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_list(response) + + def index_get( + self, + *, + namespace: str, + name: str, + timeout: Optional[int] = None, + apply_defaults: Optional[bool] = True, + ) -> types.IndexDefinition: + """ + Retrieve the information related with an index. 
+ + :param namespace: The namespace of the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + :param apply_defaults: Apply default values to parameters which are not set by user. Defaults to True. + :type apply_defaults: bool + + Returns: dict[str, Union[int, str]: Information about an index. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + + (index_stub, index_get_request, kwargs) = self._prepare_index_get( + namespace, name, timeout, logger, apply_defaults + ) + + try: + response = index_stub.Get( + index_get_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_index_get(response) + + def index_get_status( + self, *, namespace: str, name: str, timeout: Optional[int] = None + ) -> types.IndexStatusResponse: + """ + Retrieve the number of records queued to be merged into an index. + + :param namespace: The namespace of the index. + :type name: str + + :param name: The name of the index. + :type name: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Returns: IndexStatusResponse: AVS response containing index status information. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get the index status. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + Note: + This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, + the records may not immediately begin to merge into the index. + + Warning: This API is subject to change. + """ + (index_stub, index_get_status_request, kwargs) = self._prepare_index_get_status( + namespace, name, timeout, logger + ) + + try: + response = index_stub.GetStatus( + index_get_status_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + return fromIndexStatusResponse(response) + except grpc.RpcError as e: + logger.error("Failed to get index status with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + + + + def add_user( + self, + *, + username: str, + password: str, + roles: list[str], + timeout: Optional[int] = None, + ) -> None: + """ + Add role-based access AVS User to the AVS Server. + + :param username: Username for the new user. + :type username: str + + :param password: Password for the new user. + :type password: str + + :param roles: Roles for the new user. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to add a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + (user_admin_stub, add_user_request, kwargs) = self._prepare_add_user( + username, password, roles, timeout, logger + ) + + try: + user_admin_stub.AddUser( + add_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to add user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def update_credentials( + self, *, username: str, password: str, timeout: Optional[int] = None + ) -> None: + """ + Update AVS User credentials. + + :param username: Username of the user to update. + :type username: str + + :param password: New password for the user. + :type password: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update a users credentials. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, update_credentials_request, kwargs) = ( + self._prepare_update_credentials(username, password, timeout, logger) + ) + + try: + user_admin_stub.UpdateCredentials( + update_credentials_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update credentials with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def drop_user(self, *, username: str, timeout: Optional[int] = None) -> None: + """ + Drops AVS User from the AVS Server. + + :param username: Username of the user to drop. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
+ :type timeout: int + + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to drop a user + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, drop_user_request, kwargs) = self._prepare_drop_user( + username, timeout, logger + ) + + try: + user_admin_stub.DropUser( + drop_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to drop user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def get_user(self, *, username: str, timeout: Optional[int] = None) -> types.User: + """ + Retrieves AVS User information from the AVS Server. + + :param username: Username of the user to be retrieved. + :type username: str + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + return: types.User: AVS User + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to get a user. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, get_user_request, kwargs) = self._prepare_get_user( + username, timeout, logger + ) + + try: + response = user_admin_stub.GetUser( + get_user_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to get user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + return self._respond_get_user(response) + + def list_users(self, timeout: Optional[int] = None) -> list[types.User]: + """ + List all users existing on the AVS Server. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. 
+ :type timeout: int + + return: list[types.User]: list of AVS Users + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list users. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, list_users_request, kwargs) = self._prepare_list_users( + timeout, logger + ) + + try: + response = user_admin_stub.ListUsers( + list_users_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list user with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_users(response) + + def grant_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) -> None: + """ + Grant roles to existing AVS Users. + + :param username: Username of the user which will receive the roles. + :type username: str + + :param roles: Roles the specified user will receive. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to grant roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. 
+ + """ + (user_admin_stub, grant_roles_request, kwargs) = self._prepare_grant_roles( + username, roles, timeout, logger + ) + + try: + user_admin_stub.GrantRoles( + grant_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to grant roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def revoke_roles( + self, *, username: str, roles: list[str], timeout: Optional[int] = None + ) -> None: + """ + Revoke roles from existing AVS Users. + + :param username: Username of the user undergoing role removal. + :type username: str + + :param roles: Roles to be revoked. + :type roles: list[str] + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to revoke roles. + This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + + """ + (user_admin_stub, revoke_roles_request, kwargs) = self._prepare_revoke_roles( + username, roles, timeout, logger + ) + + try: + user_admin_stub.RevokeRoles( + revoke_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to revoke roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def list_roles(self, timeout: Optional[int] = None) -> list[types.Role]: + """ + List roles available on the AVS server. + + :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. + :type timeout: int + + returns: list[str]: Roles available in the AVS Server. + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to list roles. 
+ This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters. + """ + (user_admin_stub, list_roles_request, kwargs) = self._prepare_list_roles( + timeout, logger + ) + + try: + response = user_admin_stub.ListRoles( + list_roles_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to list roles with error: %s", e) + raise types.AVSServerError(rpc_error=e) + return self._respond_list_roles(response) + + def _wait_for_index_creation( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be created. + """ + + (index_stub, wait_interval, start_time, _, _, index_creation_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + while True: + self._check_timeout(start_time, timeout) + try: + index_stub.GetStatus( + index_creation_request, + credentials=self._channel_provider.get_token(), + ) + logger.debug("Index created successfully") + # Index has been created + return + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + + # Wait for some more time. + time.sleep(wait_interval) + else: + logger.error("Failed waiting for index creation with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def _wait_for_index_deletion( + self, + *, + namespace: str, + name: str, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, + ) -> None: + """ + Wait for the index to be deleted. + """ + + # Wait interval between polling + (index_stub, wait_interval, start_time, _, _, index_deletion_request) = ( + self._prepare_wait_for_index_waiting(namespace, name, wait_interval) + ) + + while True: + self._check_timeout(start_time, timeout) + + try: + index_stub.GetStatus( + index_deletion_request, + credentials=self._channel_provider.get_token(), + ) + # Wait for some more time. 
+ time.sleep(wait_interval) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.NOT_FOUND: + logger.debug("Index deleted successfully") + # Index has been created + return + else: + logger.error("Failed waiting for index deletion with error: %s", e) + raise types.AVSServerError(rpc_error=e) + def close(self): """ Close the Aerospike Vector Search Client. diff --git a/tests/rbac/sync/conftest.py b/tests/rbac/sync/conftest.py index 5c986c56..5b8f98ed 100644 --- a/tests/rbac/sync/conftest.py +++ b/tests/rbac/sync/conftest.py @@ -1,7 +1,6 @@ import pytest from aerospike_vector_search import Client -from aerospike_vector_search.admin import Client as AdminClient from aerospike_vector_search import types @@ -29,7 +28,7 @@ def drop_all_indexes( with open(private_key, "rb") as f: private_key = f.read() - with AdminClient( + with Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -47,7 +46,7 @@ def drop_all_indexes( @pytest.fixture(scope="module") -def session_rbac_admin_client( +def session_rbac_client( username, password, root_certificate, @@ -70,7 +69,7 @@ def session_rbac_admin_client( with open(private_key, "rb") as f: private_key = f.read() - client = AdminClient( + client = Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, diff --git a/tests/rbac/sync/test_admin_client_add_user.py b/tests/rbac/sync/test_admin_client_add_user.py index aa5a27c6..381b80a6 100644 --- a/tests/rbac/sync/test_admin_client_add_user.py +++ b/tests/rbac/sync/test_admin_client_add_user.py @@ -25,12 +25,12 @@ def __init__( ), ], ) -def test_add_user(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_add_user(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = 
session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username @@ -57,12 +57,12 @@ def test_add_user(session_rbac_admin_client, test_case): ), ], ) -def test_add_user_with_roles(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_add_user_with_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_drop_user.py b/tests/rbac/sync/test_admin_client_drop_user.py index a18bee71..e665836a 100644 --- a/tests/rbac/sync/test_admin_client_drop_user.py +++ b/tests/rbac/sync/test_admin_client_drop_user.py @@ -24,13 +24,13 @@ def __init__( ), ], ) -def test_drop_user(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_drop_user(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - session_rbac_admin_client.drop_user( + session_rbac_client.drop_user( username=test_case.username, ) with pytest.raises(AVSServerError) as e_info: - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert e_info.value.rpc_error.code() == grpc.StatusCode.NOT_FOUND diff --git a/tests/rbac/sync/test_admin_client_get_user.py b/tests/rbac/sync/test_admin_client_get_user.py index 69c76a92..61bdc405 100644 --- a/tests/rbac/sync/test_admin_client_get_user.py +++ b/tests/rbac/sync/test_admin_client_get_user.py @@ -22,12 +22,12 @@ def __init__( ), ], ) -def test_get_user(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_get_user(session_rbac_client, test_case): + 
session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_grant_roles.py b/tests/rbac/sync/test_admin_client_grant_roles.py index dc74bef0..8ce18884 100644 --- a/tests/rbac/sync/test_admin_client_grant_roles.py +++ b/tests/rbac/sync/test_admin_client_grant_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, granted_roles): ), ], ) -def test_grant_roles(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_grant_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - session_rbac_admin_client.grant_roles( + session_rbac_client.grant_roles( username=test_case.username, roles=test_case.granted_roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_list_roles.py b/tests/rbac/sync/test_admin_client_list_roles.py index a3d1bf82..35030910 100644 --- a/tests/rbac/sync/test_admin_client_list_roles.py +++ b/tests/rbac/sync/test_admin_client_list_roles.py @@ -25,11 +25,11 @@ def __init__( ), ], ) -def test_list_roles(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_list_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - result = session_rbac_admin_client.list_roles() + result = session_rbac_client.list_roles() for role in result: assert role.id in test_case.roles diff --git a/tests/rbac/sync/test_admin_client_list_users.py 
b/tests/rbac/sync/test_admin_client_list_users.py index 0b713250..d029472d 100644 --- a/tests/rbac/sync/test_admin_client_list_users.py +++ b/tests/rbac/sync/test_admin_client_list_users.py @@ -17,12 +17,12 @@ def __init__(self, *, username, password): ), ], ) -def test_list_users(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_list_users(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=None ) - result = session_rbac_admin_client.list_users() + result = session_rbac_client.list_users() user_found = False for user in result: if user.username == test_case.username: diff --git a/tests/rbac/sync/test_admin_client_revoke_roles.py b/tests/rbac/sync/test_admin_client_revoke_roles.py index 0620fb24..04d7d704 100644 --- a/tests/rbac/sync/test_admin_client_revoke_roles.py +++ b/tests/rbac/sync/test_admin_client_revoke_roles.py @@ -21,16 +21,16 @@ def __init__(self, *, username, password, roles, revoked_roles): ), ], ) -def test_revoke_roles(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_revoke_roles(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.password, roles=test_case.roles ) - session_rbac_admin_client.revoke_roles( + session_rbac_client.revoke_roles( username=test_case.username, roles=test_case.roles ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/rbac/sync/test_admin_client_update_credentials.py b/tests/rbac/sync/test_admin_client_update_credentials.py index 3e2d7894..328c05b2 100644 --- a/tests/rbac/sync/test_admin_client_update_credentials.py +++ b/tests/rbac/sync/test_admin_client_update_credentials.py @@ -19,17 +19,17 @@ def __init__(self, *, username, old_password, new_password): ), 
], ) -def test_update_credentials(session_rbac_admin_client, test_case): - session_rbac_admin_client.add_user( +def test_update_credentials(session_rbac_client, test_case): + session_rbac_client.add_user( username=test_case.username, password=test_case.old_password, roles=None ) - session_rbac_admin_client.update_credentials( + session_rbac_client.update_credentials( username=test_case.username, password=test_case.new_password, ) - result = session_rbac_admin_client.get_user(username=test_case.username) + result = session_rbac_client.get_user(username=test_case.username) assert result.username == test_case.username diff --git a/tests/standard/conftest.py b/tests/standard/conftest.py index f5590e64..94d14e32 100644 --- a/tests/standard/conftest.py +++ b/tests/standard/conftest.py @@ -5,8 +5,6 @@ from aerospike_vector_search import Client from aerospike_vector_search.aio import Client as AsyncClient -from aerospike_vector_search.admin import Client as AdminClient -from aerospike_vector_search.aio.admin import Client as AsyncAdminClient from aerospike_vector_search import types, AVSServerError from utils import gen_records, DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD @@ -63,7 +61,7 @@ def drop_all_indexes( with open(private_key, "rb") as f: private_key = f.read() - with AdminClient( + with Client( seeds=types.HostPort(host=host, port=port), is_loadbalancer=is_loadbalancer, username=username, @@ -129,30 +127,6 @@ async def new_wrapped_async_client( ) -async def new_wrapped_async_admin_client( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - is_loadbalancer, - ssl_target_name_override, - loop -): - return AsyncAdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override - ) - 
- class AsyncClientWrapper(): def __init__(self, client, loop): self.client = client @@ -177,63 +151,6 @@ def _run_async_task(self, task): raise RuntimeError("Event loop is not running") -@pytest.fixture(scope="module") -def session_admin_client( - username, - password, - root_certificate, - host, - port, - certificate_chain, - private_key, - is_loadbalancer, - ssl_target_name_override, - async_client, - event_loop, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - if async_client: - task = new_wrapped_async_admin_client( - host=host, - port=port, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - is_loadbalancer=is_loadbalancer, - ssl_target_name_override=ssl_target_name_override, - loop=event_loop - ) - client = asyncio.run_coroutine_threadsafe(task, event_loop).result() - client = AsyncClientWrapper(client, event_loop) - else: - client = AdminClient( - seeds=types.HostPort(host=host, port=port), - is_loadbalancer=is_loadbalancer, - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override - ) - - yield client - client.close() - - @pytest.fixture(scope="module") def session_vector_client( username, @@ -298,12 +215,12 @@ def index_name(): @pytest.fixture(params=[DEFAULT_INDEX_ARGS]) -def index(session_admin_client, index_name, request): +def index(session_vector_client, index_name, request): args = request.param namespace = args.get("namespace", DEFAULT_NAMESPACE) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) - 
session_admin_client.index_create( + session_vector_client.index_create( name = index_name, namespace = namespace, vector_field = vector_field, @@ -324,7 +241,7 @@ def index(session_admin_client, index_name, request): ) yield index_name try: - session_admin_client.index_drop(namespace=namespace, name=index_name) + session_vector_client.index_drop(namespace=namespace, name=index_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -411,11 +328,6 @@ def pytest_addoption(parser): action="store_true", help="Skip the test if latency is too low to effectively trigger timeout", ) - parser.addoption( - "--extensive_vector_search", - action="store_true", - help="Run extensive vector search testing", - ) parser.addoption( "--async", action="store_true", @@ -484,8 +396,3 @@ def is_loadbalancer(request): @pytest.fixture(scope="module", autouse=True) def with_latency(request): return request.config.getoption("--with_latency") - - -@pytest.fixture(scope="module", autouse=True) -def extensive_vector_search(request): - return request.config.getoption("--extensive_vector_search") diff --git a/tests/standard/test_admin_client_index_create.py b/tests/standard/test_admin_client_index_create.py index d869506f..bb64ae47 100644 --- a/tests/standard/test_admin_client_index_create.py +++ b/tests/standard/test_admin_client_index_create.py @@ -62,14 +62,14 @@ def __init__( ) ], ) -def test_index_create(session_admin_client, test_case, random_name): +def test_index_create(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -82,7 +82,7 @@ def 
test_index_create(session_admin_client, test_case, random_name): timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -100,7 +100,7 @@ def test_index_create(session_admin_client, test_case, random_name): assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -132,15 +132,15 @@ def test_index_create(session_admin_client, test_case, random_name): ), ], ) -def test_index_create_with_dimnesions(session_admin_client, test_case, random_name): +def test_index_create_with_dimnesions(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -153,7 +153,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: @@ -174,7 +174,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) 
#@given(random_name=index_strategy()) @@ -229,16 +229,16 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na ], ) def test_index_create_with_vector_distance_metric( - session_admin_client, test_case, random_name + session_vector_client, test_case, random_name ): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -250,7 +250,7 @@ def test_index_create_with_vector_distance_metric( index_storage=test_case.index_storage, timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -268,7 +268,7 @@ def test_index_create_with_vector_distance_metric( assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -300,15 +300,15 @@ def test_index_create_with_vector_distance_metric( ), ], ) -def test_index_create_with_sets(session_admin_client, test_case, random_name): +def test_index_create_with_sets(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( 
namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -320,7 +320,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): index_storage=test_case.index_storage, timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -338,7 +338,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -429,13 +429,13 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): ), ], ) -def test_index_create_with_index_params(session_admin_client, test_case, random_name): +def test_index_create_with_index_params(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -448,7 +448,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -534,7 +534,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ assert result["storage"]["namespace"] == 
test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -555,13 +555,13 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ ) ], ) -def test_index_create_index_labels(session_admin_client, test_case, random_name): +def test_index_create_index_labels(session_vector_client, test_case, random_name): try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -574,7 +574,7 @@ def test_index_create_index_labels(session_admin_client, test_case, random_name) timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -595,7 +595,7 @@ def test_index_create_index_labels(session_admin_client, test_case, random_name) assert result["storage"]["namespace"] == test_case.namespace assert result["storage"]["set_name"] == random_name assert found == True - drop_specified_index(session_admin_client, test_case.namespace, random_name) + drop_specified_index(session_vector_client, test_case.namespace, random_name) #@given(random_name=index_strategy()) @@ -616,13 +616,13 @@ def test_index_create_index_labels(session_admin_client, test_case, random_name) ), ], ) -def test_index_create_index_storage(session_admin_client, test_case, random_name): +def test_index_create_index_storage(session_vector_client, test_case, random_name): try: 
- session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, @@ -635,7 +635,7 @@ def test_index_create_index_storage(session_admin_client, test_case, random_name timeout=test_case.timeout, ) - results = session_admin_client.index_list() + results = session_vector_client.index_list() found = False for result in results: if result["id"]["name"] == random_name: @@ -674,20 +674,20 @@ def test_index_create_index_storage(session_admin_client, test_case, random_name ], ) def test_index_create_timeout( - session_admin_client, test_case, random_name, with_latency + session_vector_client, test_case, random_name, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") try: - session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass for i in range(10): try: - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=random_name, vector_field=test_case.vector_field, diff --git a/tests/standard/test_admin_client_index_drop.py b/tests/standard/test_admin_client_index_drop.py index cf05a88f..a21f90c6 100644 --- a/tests/standard/test_admin_client_index_drop.py +++ b/tests/standard/test_admin_client_index_drop.py @@ -11,10 +11,10 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_drop(session_admin_client, empty_test_case, index): - 
session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) +def test_index_drop(session_vector_client, empty_test_case, index): + session_vector_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) - result = session_admin_client.index_list() + result = session_vector_client.index_list() result = result for index in result: assert index["id"]["name"] != index @@ -24,7 +24,7 @@ def test_index_drop(session_admin_client, empty_test_case, index): #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_drop_timeout( - session_admin_client, + session_vector_client, empty_test_case, index, with_latency @@ -34,7 +34,7 @@ def test_index_drop_timeout( for i in range(10): try: - session_admin_client.index_drop( + session_vector_client.index_drop( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/test_admin_client_index_get.py b/tests/standard/test_admin_client_index_get.py index c6a77e00..4412b57e 100644 --- a/tests/standard/test_admin_client_index_get.py +++ b/tests/standard/test_admin_client_index_get.py @@ -8,8 +8,8 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get(session_admin_client, empty_test_case, index): - result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) +def test_index_get(session_vector_client, empty_test_case, index): + result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result["id"]["name"] == index assert result["id"]["namespace"] == DEFAULT_NAMESPACE @@ -47,9 +47,9 @@ def test_index_get(session_admin_client, empty_test_case, index): @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_no_defaults(session_admin_client, 
empty_test_case, index): +async def test_index_get_no_defaults(session_vector_client, empty_test_case, index): - result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) + result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) assert result["id"]["name"] == index assert result["id"]["namespace"] == DEFAULT_NAMESPACE @@ -89,14 +89,14 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, inde #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_timeout( - session_admin_client, empty_test_case, index, with_latency + session_vector_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") for i in range(10): try: - result = session_admin_client.index_get( + result = session_vector_client.index_get( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) diff --git a/tests/standard/test_admin_client_index_get_status.py b/tests/standard/test_admin_client_index_get_status.py index 93fa8c83..016c0027 100644 --- a/tests/standard/test_admin_client_index_get_status.py +++ b/tests/standard/test_admin_client_index_get_status.py @@ -13,25 +13,25 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get_status(session_admin_client, empty_test_case, index): - result = session_admin_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) +def test_index_get_status(session_vector_client, empty_test_case, index): + result = session_vector_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) assert result.unmerged_record_count == 0 - drop_specified_index(session_admin_client, DEFAULT_NAMESPACE, index) + drop_specified_index(session_vector_client, DEFAULT_NAMESPACE, index) @pytest.mark.parametrize("empty_test_case", [None]) 
#@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_status_timeout( - session_admin_client, empty_test_case, index, with_latency + session_vector_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") for i in range(10): try: - result = session_admin_client.index_get_status( + result = session_vector_client.index_get_status( namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/test_admin_client_index_list.py b/tests/standard/test_admin_client_index_list.py index 55e3978b..9c6c51f0 100644 --- a/tests/standard/test_admin_client_index_list.py +++ b/tests/standard/test_admin_client_index_list.py @@ -9,8 +9,8 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_list(session_admin_client, empty_test_case, index): - result = session_admin_client.index_list(apply_defaults=True) +def test_index_list(session_vector_client, empty_test_case, index): + result = session_vector_client.index_list(apply_defaults=True) assert len(result) > 0 for index in result: assert isinstance(index["id"]["name"], str) @@ -32,7 +32,7 @@ def test_index_list(session_admin_client, empty_test_case, index): #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_list_timeout( - session_admin_client, empty_test_case, with_latency + session_vector_client, empty_test_case, with_latency ): if not with_latency: @@ -41,7 +41,7 @@ def test_index_list_timeout( for i in range(10): try: - result = session_admin_client.index_list(timeout=0.0001) + result = session_vector_client.index_list(timeout=0.0001) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/test_admin_client_index_update.py b/tests/standard/test_admin_client_index_update.py 
index e5be9822..eece36e4 100644 --- a/tests/standard/test_admin_client_index_update.py +++ b/tests/standard/test_admin_client_index_update.py @@ -44,9 +44,9 @@ def __init__( ), ], ) -def test_index_update(session_admin_client, test_case, index): +def test_index_update(session_vector_client, test_case, index): # Update the index with parameters based on the test case - session_admin_client.index_update( + session_vector_client.index_update( namespace=DEFAULT_NAMESPACE, name=index, index_labels=test_case.update_labels, @@ -57,7 +57,7 @@ def test_index_update(session_admin_client, test_case, index): time.sleep(10) # Verify the update - result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) + result = session_vector_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." assert result["id"]["namespace"] == DEFAULT_NAMESPACE diff --git a/tests/standard/test_extensive_vector_search.py b/tests/standard/test_extensive_vector_search.py deleted file mode 100644 index 0e9efc43..00000000 --- a/tests/standard/test_extensive_vector_search.py +++ /dev/null @@ -1,416 +0,0 @@ -import numpy as np -import pytest -import time -from aerospike_vector_search import types -from aerospike_vector_search import AVSServerError -import grpc - -dimensions = 128 -truth_vector_dimensions = 100 -base_vector_number = 10_000 -query_vector_number = 100 - - -def parse_sift_to_numpy_array(length, dim, byte_buffer, dtype): - numpy = np.empty((length,), dtype=object) - - record_length = (dim * 4) + 4 - - for i in range(length): - current_offset = i * record_length - begin = current_offset - vector_begin = current_offset + 4 - end = current_offset + record_length - if np.frombuffer(byte_buffer[begin:vector_begin], dtype=np.int32)[0] != dim: - raise Exception("Failed to parse byte buffer correctly") - numpy[i] = np.frombuffer(byte_buffer[vector_begin:end], 
dtype=dtype) - return numpy - - -@pytest.fixture -def base_numpy(): - base_filename = "siftsmall/siftsmall_base.fvecs" - with open(base_filename, "rb") as file: - base_bytes = bytearray(file.read()) - - base_numpy = parse_sift_to_numpy_array( - base_vector_number, dimensions, base_bytes, np.float32 - ) - - return base_numpy - - -@pytest.fixture -def truth_numpy(): - truth_filename = "siftsmall/siftsmall_groundtruth.ivecs" - with open(truth_filename, "rb") as file: - truth_bytes = bytearray(file.read()) - - truth_numpy = parse_sift_to_numpy_array( - query_vector_number, truth_vector_dimensions, truth_bytes, np.int32 - ) - - return truth_numpy - - -@pytest.fixture -def query_numpy(): - query_filename = "siftsmall/siftsmall_query.fvecs" - with open(query_filename, "rb") as file: - query_bytes = bytearray(file.read()) - - query_numpy = parse_sift_to_numpy_array( - query_vector_number, dimensions, query_bytes, np.float32 - ) - - return query_numpy - - -def put_vector(client, vector, j, set_name): - client.upsert( - namespace="test", - key=str(j), - record_data={"unit_test": vector}, - set_name=set_name, - ) - - -def get_vector(client, j, set_name): - result = client.get(namespace="test", key=str(j), set_name=set_name) - - -def vector_search(client, vector, name): - result = client.vector_search( - namespace="test", - index_name=name, - query=vector, - limit=100, - field_names=["unit_test"], - ) - return result - - -def vector_search_ef_80(client, vector, name): - result = client.vector_search( - namespace="test", - index_name=name, - query=vector, - limit=100, - field_names=["unit_test"], - search_params=types.HnswSearchParams(ef=80), - ) - return result - - -def grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name, -): - - # Vector search all query vectors - results = [] - count = 0 - for i in query_numpy: - if count % 2: - results.append(vector_search(session_vector_client, i, name)) - else: - 
results.append(vector_search_ef_80(session_vector_client, i, name)) - count += 1 - - # Get recall numbers for each query - recall_for_each_query = [] - for i, outside in enumerate(truth_numpy): - true_positive = 0 - false_negative = 0 - # Parse all fields for each neighbor into an array - field_list = [] - - for j, result in enumerate(results[i]): - field_list.append(result.fields["unit_test"]) - for j, index in enumerate(outside): - vector = base_numpy[index].tolist() - if vector in field_list: - true_positive = true_positive + 1 - else: - false_negative = false_negative + 1 - - recall = true_positive / (true_positive + false_negative) - recall_for_each_query.append(recall) - - # Calculate the sum of all values - recall_sum = sum(recall_for_each_query) - - # Calculate the average - average = recall_sum / len(recall_for_each_query) - - assert average > 0.95 - for recall in recall_for_each_query: - assert recall > 0.9 - - -def test_vector_search( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - session_admin_client.index_create( - namespace="test", - name="demo1", - vector_field="unit_test", - dimensions=128, - ) - - for j, vector in enumerate(base_numpy): - put_vector(session_vector_client, vector, j, None) - - session_vector_client.wait_for_index_completion(namespace="test", name="demo1") - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo1", - ) - - -def test_vector_search_with_set_same_as_index( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests 
disabled") - - session_admin_client.index_create( - namespace="test", - name="demo2", - sets="demo2", - vector_field="unit_test", - dimensions=128, - index_storage=types.IndexStorage(namespace="test", set_name="demo2"), - ) - - for j, vector in enumerate(base_numpy): - put_vector(session_vector_client, vector, j, "demo2") - - for j, vector in enumerate(base_numpy): - get_vector(session_vector_client, j, "demo2") - - session_vector_client.wait_for_index_completion(namespace="test", name="demo2") - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo2", - ) - - -def test_vector_search_with_set_different_than_name( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - session_admin_client.index_create( - namespace="test", - name="demo3", - vector_field="unit_test", - dimensions=128, - sets="example1", - index_storage=types.IndexStorage(namespace="test", set_name="demo3"), - ) - - for j, vector in enumerate(base_numpy): - put_vector(session_vector_client, vector, j, "example1") - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - session_vector_client.wait_for_index_completion(namespace="test", name="demo3") - - grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo3", - ) - - -def test_vector_search_with_index_storage_different_than_name( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - 
session_admin_client.index_create( - namespace="test", - name="demo4", - vector_field="unit_test", - dimensions=128, - sets="demo4", - index_storage=types.IndexStorage(namespace="test", set_name="example2"), - ) - - for j, vector in enumerate(base_numpy): - put_vector(session_vector_client, vector, j, "demo4") - - session_vector_client.wait_for_index_completion(namespace="test", name="demo4") - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo4", - ) - - -def test_vector_search_with_index_storage_different_location( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - session_admin_client.index_create( - namespace="test", - name="demo5", - vector_field="unit_test", - dimensions=128, - sets="example3", - index_storage=types.IndexStorage(namespace="test", set_name="example4"), - ) - - for j, vector in enumerate(base_numpy): - put_vector(session_vector_client, vector, j, "example3") - - session_vector_client.wait_for_index_completion(namespace="test", name="demo5") - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo5", - ) - - -def test_vector_search_with_separate_namespace( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - extensive_vector_search, -): - - if not extensive_vector_search: - pytest.skip("Extensive vector tests disabled") - - session_admin_client.index_create( - namespace="test", - name="demo6", - vector_field="unit_test", - 
dimensions=128, - sets="demo6", - index_storage=types.IndexStorage(namespace="index_storage", set_name="demo6"), - ) - - for j, vector in enumerate(base_numpy): - put_vector(session_vector_client, vector, j, "demo6") - - session_vector_client.wait_for_index_completion(namespace="test", name="demo6") - - # Wait for index completion isn't perfect - # give the index some extra time since accuracy is the point of this test - time.sleep(5) - - grade_results( - base_numpy, - truth_numpy, - query_numpy, - session_vector_client, - session_admin_client, - name="demo6", - ) - - -def test_vector_vector_search_timeout( - session_vector_client, session_admin_client, with_latency -): - if not with_latency: - pytest.skip("Server latency too low to test timeout") - - for i in range(10): - try: - result = session_vector_client.vector_search( - namespace="test", - index_name="demo2", - query=[0, 1, 2], - limit=100, - field_names=["unit_test"], - timeout=0.0001, - ) - except AVSServerError as se: - if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - assert se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED - return - assert "In several attempts, the timeout did not happen" == "TEST FAIL" diff --git a/tests/standard/test_vector_client_is_indexed.py b/tests/standard/test_vector_client_is_indexed.py index b2434bda..5284fd6e 100644 --- a/tests/standard/test_vector_client_is_indexed.py +++ b/tests/standard/test_vector_client_is_indexed.py @@ -8,14 +8,13 @@ def test_vector_is_indexed( - session_admin_client, session_vector_client, index, record, ): # wait for the record to be indexed wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace=DEFAULT_NAMESPACE, index=index ) diff --git a/tests/standard/test_vector_client_search_by_key.py b/tests/standard/test_vector_client_search_by_key.py index 5ea3df98..d0141023 100644 --- a/tests/standard/test_vector_client_search_by_key.py +++ b/tests/standard/test_vector_client_search_by_key.py 
@@ -265,11 +265,10 @@ def __init__( ) def test_vector_search_by_key( session_vector_client, - session_admin_client, test_case, ): - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.search_namespace, name=test_case.index_name, vector_field=test_case.vector_field, @@ -298,7 +297,7 @@ def test_vector_search_by_key( ) wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace=test_case.search_namespace, index=test_case.index_name, ) @@ -324,7 +323,7 @@ def test_vector_search_by_key( key=key, ) - session_admin_client.index_drop( + session_vector_client.index_drop( namespace=test_case.search_namespace, name=test_case.index_name, ) @@ -332,10 +331,9 @@ def test_vector_search_by_key( def test_vector_search_by_key_different_namespaces( session_vector_client, - session_admin_client, ): - session_admin_client.index_create( + session_vector_client.index_create( namespace="index_storage", name="diff_ns_idx", vector_field="vec", @@ -374,7 +372,7 @@ def test_vector_search_by_key_different_namespaces( ) wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace="index_storage", index="diff_ns_idx", ) @@ -415,7 +413,7 @@ def test_vector_search_by_key_different_namespaces( key="search_for", ) - session_admin_client.index_drop( + session_vector_client.index_drop( namespace="index_storage", name="diff_ns_idx", ) \ No newline at end of file diff --git a/tests/standard/test_vector_search.py b/tests/standard/test_vector_search.py index a18414b4..82c6326f 100644 --- a/tests/standard/test_vector_search.py +++ b/tests/standard/test_vector_search.py @@ -99,11 +99,10 @@ def __init__( ) def test_vector_search( session_vector_client, - session_admin_client, test_case, ): - session_admin_client.index_create( + session_vector_client.index_create( namespace=test_case.namespace, name=test_case.index_name, vector_field=test_case.vector_field, @@ -132,7 +131,7 @@ def 
test_vector_search( ) wait_for_index( - admin_client=session_admin_client, + admin_client=session_vector_client, namespace=test_case.namespace, index=test_case.index_name, ) @@ -154,7 +153,7 @@ def test_vector_search( key=key, ) - session_admin_client.index_drop( + session_vector_client.index_drop( namespace=test_case.namespace, name=test_case.index_name, )