diff --git a/src/aerospike_vector_search/aio/admin.py b/src/aerospike_vector_search/aio/admin.py index eae2a817..aa57bcd1 100644 --- a/src/aerospike_vector_search/aio/admin.py +++ b/src/aerospike_vector_search/aio/admin.py @@ -386,7 +386,7 @@ async def index_get_status( Note: This method retrieves the status of the specified index. If index_get_status is called the vector client puts some records into Aerospike Vector Search, - the records may not immediately begin to merge into the index. To wait for all records to be merged into an index, use vector_client.wait_for_index_completion. + the records may not immediately begin to merge into the index. Warning: This API is subject to change. """ diff --git a/src/aerospike_vector_search/types.py b/src/aerospike_vector_search/types.py index 8cb7b71d..59e3c214 100644 --- a/src/aerospike_vector_search/types.py +++ b/src/aerospike_vector_search/types.py @@ -528,21 +528,36 @@ class HnswParams(object): """ Parameters for the Hierarchical Navigable Small World (HNSW) algorithm, used for approximate nearest neighbor search. - :param m: The number of bi-directional links created per level during construction. Larger 'm' values lead to higher recall but slower construction. Defaults to 16. + :param m: The number of bi-directional links created per level during construction. Larger 'm' values lead to higher recall but slower construction. Optional, Defaults to 16. :type m: Optional[int] - :param ef_construction: The size of the dynamic list for the nearest neighbors (candidates) during the index construction. Larger 'ef_construction' values lead to higher recall but slower construction. Defaults to 100. + :param ef_construction: The size of the dynamic list for the nearest neighbors (candidates) during the index construction. Larger 'ef_construction' values lead to higher recall but slower construction. Optional, Defaults to 100. :type ef_construction: Optional[int] - :param ef: The size of the dynamic list for the nearest neighbors (candidates) during the search phase. Larger 'ef' values lead to higher recall but slower search. Defaults to 100. + :param ef: The size of the dynamic list for the nearest neighbors (candidates) during the search phase. Larger 'ef' values lead to higher recall but slower search. Optional, Defaults to 100. :type ef: Optional[int] - :param batching_params: Parameters related to configuring batch processing, such as the maximum number of records per batch and batching interval. Defaults to HnswBatchingParams(). - :type batching_params: Optional[HnswBatchingParams] + :param batching_params: Parameters related to configuring batch processing, such as the maximum number of records per batch and batching interval. Optional, Defaults to HnswBatchingParams(). + :type batching_params: HnswBatchingParams - :param enable_vector_integrity_check: Verifies if the underlying vector has changed before returning the kANN result. + :param max_mem_queue_size: Maximum size of in-memory queue for inserted/updated vector records. Optional, Defaults to the corresponding config on the AVS Server. + :type max_mem_queue_size: Optional[int] + + :param index_caching_params: Parameters related to configuring caching for the HNSW index. Optional, Defaults to HnswCachingParams(). + :type index_caching_params: HnswCachingParams + + :param healer_params: Parameters related to configuring the HNSW index healer. Optional, Defaults to HnswHealerParams(). + :type healer_params: HnswHealerParams + + :param merge_params: Parameters related to configuring the merging of index records. Optional, Defaults to HnswIndexMergeParams(). + :type merge_params: HnswIndexMergeParams + + :param enable_vector_integrity_check: Verifies if the underlying vector has changed before returning the kANN result. Optional, Defaults to True. :type enable_vector_integrity_check: Optional[bool] + :param record_caching_params: Parameters related to configuring caching for vector records. Optional, Defaults to HnswCachingParams(). + :type record_caching_params: HnswCachingParams + """ def __init__( diff --git a/tests/standard/aio/aio_utils.py b/tests/standard/aio/aio_utils.py index 6dea216e..2f39646b 100644 --- a/tests/standard/aio/aio_utils.py +++ b/tests/standard/aio/aio_utils.py @@ -1,6 +1,10 @@ +import asyncio + + async def drop_specified_index(admin_client, namespace, name): await admin_client.index_drop(namespace=namespace, name=name) + def gen_records(count: int, vec_bin: str, vec_dim: int): num = 0 while num < count: @@ -10,3 +14,22 @@ def gen_records(count: int, vec_bin: str, vec_dim: int): ) yield key_and_rec num += 1 + + +async def wait_for_index(admin_client, namespace: str, index: str): + + verticies = 0 + unmerged_recs = 0 + + while verticies == 0 or unmerged_recs > 0: + status = await admin_client.index_get_status( + namespace=namespace, + name=index, + ) + + verticies = status.index_healer_vertices_valid + unmerged_recs = status.unmerged_record_count + + # print(verticies) + # print(unmerged_recs) + await asyncio.sleep(0.5) \ No newline at end of file diff --git a/tests/standard/aio/conftest.py b/tests/standard/aio/conftest.py index 97429f6d..354bba04 100644 --- a/tests/standard/aio/conftest.py +++ b/tests/standard/aio/conftest.py @@ -2,12 +2,14 @@ import pytest import random import string +import grpc from aerospike_vector_search.aio import Client from aerospike_vector_search.aio.admin import Client as AdminClient -from aerospike_vector_search import types +from aerospike_vector_search import types, AVSServerError from .aio_utils import gen_records +import utils #import logging #logger = logging.getLogger(__name__) @@ -15,9 +17,9 @@ # default test values -DEFAULT_NAMESPACE = "test" -DEFAULT_INDEX_DIMENSION = 128 -DEFAULT_VECTOR_FIELD = "vector" +DEFAULT_NAMESPACE = utils.DEFAULT_NAMESPACE +DEFAULT_INDEX_DIMENSION = utils.DEFAULT_INDEX_DIMENSION +DEFAULT_VECTOR_FIELD = utils.DEFAULT_VECTOR_FIELD DEFAULT_INDEX_ARGS = { "namespace": DEFAULT_NAMESPACE, "vector_field": DEFAULT_VECTOR_FIELD, @@ -199,14 +201,37 @@ def index_name(): @pytest.fixture(params=[DEFAULT_INDEX_ARGS]) async def index(session_admin_client, index_name, request): - index_args = request.param + args = request.param + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) await session_admin_client.index_create( name = index_name, - **index_args, + namespace = namespace, + vector_field = vector_field, + dimensions = dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) yield index_name - namespace = index_args.get("namespace", DEFAULT_NAMESPACE) - await session_admin_client.index_drop(namespace=namespace, name=index_name) + try: + await session_admin_client.index_drop(namespace=namespace, name=index_name) + except AVSServerError as se: + if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: + pass + else: + raise @pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) @@ -217,10 +242,35 @@ async def records(session_vector_client, request): num_records = args.get("num_records", DEFAULT_NUM_RECORDS) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) keys = [] for key, rec in record_generator(count=num_records, vec_bin=vector_field, vec_dim=dimensions): - await session_vector_client.upsert(namespace=namespace, key=key, record_data=rec) + await session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) keys.append(key) - yield len(keys) + yield keys for key in keys: - await session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file + await session_vector_client.delete(key=key, namespace=namespace) + + +@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) +async def record(session_vector_client, request): + args = request.param + record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) + key, rec = next(record_generator(count=1, vec_bin=vector_field, vec_dim=dimensions)) + await session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) + yield key + await session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file diff --git a/tests/standard/aio/test_admin_client_index_create.py b/tests/standard/aio/test_admin_client_index_create.py index 251c9252..0897afca 100644 --- a/tests/standard/aio/test_admin_client_index_create.py +++ b/tests/standard/aio/test_admin_client_index_create.py @@ -2,7 +2,7 @@ from aerospike_vector_search import types, AVSServerError import grpc -from ...utils import random_name +from ...utils import random_name, DEFAULT_NAMESPACE from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity, Phase @@ -51,7 +51,7 @@ def __init__( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_1", dimensions=1024, vector_distance_metric=None, @@ -105,7 +105,7 @@ async def test_index_create(session_admin_client, test_case, random_name): "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_2", dimensions=495, vector_distance_metric=None, @@ -116,7 +116,7 @@ async def test_index_create(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_3", dimensions=2048, vector_distance_metric=None, @@ -174,7 +174,7 @@ async def test_index_create_with_dimnesions( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_4", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.COSINE, @@ -185,7 +185,7 @@ async def test_index_create_with_dimnesions( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_5", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.DOT_PRODUCT, @@ -196,7 +196,7 @@ async def test_index_create_with_dimnesions( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_6", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.MANHATTAN, @@ -207,7 +207,7 @@ async def test_index_create_with_dimnesions( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_7", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.HAMMING, @@ -262,7 +262,7 @@ async def test_index_create_with_vector_distance_metric( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_8", dimensions=1024, vector_distance_metric=None, @@ -273,7 +273,7 @@ async def test_index_create_with_vector_distance_metric( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_9", dimensions=1024, vector_distance_metric=None, @@ -326,7 +326,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_10", dimensions=1024, vector_distance_metric=None, @@ -342,7 +342,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_11", dimensions=1024, vector_distance_metric=None, @@ -358,7 +358,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_12", dimensions=1024, vector_distance_metric=None, @@ -372,7 +372,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_13", dimensions=1024, vector_distance_metric=None, @@ -386,7 +386,7 @@ async def test_index_create_with_sets(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_20", dimensions=1024, vector_distance_metric=None, @@ -473,7 +473,7 @@ async def test_index_create_with_index_params( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_14", dimensions=1024, vector_distance_metric=None, @@ -527,14 +527,14 @@ async def test_index_create_index_labels(session_admin_client, test_case, random "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_15", dimensions=1024, vector_distance_metric=None, sets=None, index_params=None, index_labels=None, - index_storage=types.IndexStorage(namespace="test", set_name="foo"), + index_storage=types.IndexStorage(namespace=DEFAULT_NAMESPACE, set_name="foo"), timeout=None, ), ], @@ -578,7 +578,7 @@ async def test_index_create_index_storage(session_admin_client, test_case, rando "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_16", dimensions=1024, vector_distance_metric=None, diff --git a/tests/standard/aio/test_admin_client_index_drop.py b/tests/standard/aio/test_admin_client_index_drop.py index b49e473b..9128a5c8 100644 --- a/tests/standard/aio/test_admin_client_index_drop.py +++ b/tests/standard/aio/test_admin_client_index_drop.py @@ -2,7 +2,7 @@ from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_name +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -11,40 +11,29 @@ @pytest.mark.parametrize("empty_test_case", [None, None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=2000) -async def test_index_drop(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - dimensions=1024, - ) - await session_admin_client.index_drop(namespace="test", name=random_name) +async def test_index_drop(session_admin_client, empty_test_case, index): + await session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) result = session_admin_client.index_list() result = await result for index in result: - assert index["id"]["name"] != random_name + assert index["id"]["name"] != index @pytest.mark.parametrize("empty_test_case", [None, None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) async def test_index_drop_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - dimensions=1024, - ) + with pytest.raises(AVSServerError) as e_info: for i in range(10): await session_admin_client.index_drop( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_admin_client_index_get.py b/tests/standard/aio/test_admin_client_index_get.py index 5ebeaf1f..ff21b3e9 100644 --- a/tests/standard/aio/test_admin_client_index_get.py +++ b/tests/standard/aio/test_admin_client_index_get.py @@ -1,7 +1,6 @@ import pytest -from ...utils import random_name +from ...utils import DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD -from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity from aerospike_vector_search import AVSServerError @@ -11,29 +10,23 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - result = await session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=True) - - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" +async def test_index_get(session_admin_client, empty_test_case, index): + result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) + + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD assert result["hnsw_params"]["m"] == 16 assert result["hnsw_params"]["ef_construction"] == 100 assert result["hnsw_params"]["ef"] == 100 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == "test" - assert result["storage"].set_name == random_name - assert result["storage"]["set_name"] == random_name + assert result["storage"]["namespace"] == DEFAULT_NAMESPACE + assert result["storage"].set_name == index + assert result["storage"]["set_name"] == index # Defaults assert result["sets"] == "" @@ -46,33 +39,24 @@ async def test_index_get(session_admin_client, empty_test_case, random_name): assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 1000 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 10000 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 10.0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "0 0/15 * ? * * *" + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" assert result["hnsw_params"]["healer_params"]["parallelism"] == 1 # index parallelism and reindex parallelism are dynamic depending on the CPU cores of the host # assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 80 # assert result["hnsw_params"]["merge_params"]["reindex_parallelism"] == 26 - await drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_no_defaults(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +async def test_index_get_no_defaults(session_admin_client, empty_test_case, index): + result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) - result = await session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=False) - - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD # Defaults assert result["sets"] == "" @@ -82,7 +66,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["hnsw_params"]["ef"] == 0 assert result["hnsw_params"]["ef_construction"] == 0 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 0 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 0 + # This is set by default to 10000 in the index fixture + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["max_mem_queue_size"] == 0 assert result["hnsw_params"]["index_caching_params"]["max_entries"] == 0 assert result["hnsw_params"]["index_caching_params"]["expiry"] == 0 @@ -90,7 +75,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 0 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 0 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "" + # This is set by default to * * * * * ? in the index fixture + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" assert result["hnsw_params"]["healer_params"]["parallelism"] == 0 assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 0 @@ -100,14 +86,12 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["storage"].set_name == "" assert result["storage"]["set_name"] == "" - await drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) async def test_index_get_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: @@ -116,7 +100,7 @@ async def test_index_get_timeout( for i in range(10): try: result = await session_admin_client.index_get( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/aio/test_admin_client_index_get_status.py b/tests/standard/aio/test_admin_client_index_get_status.py index 3f8bd4cb..fec3fc21 100644 --- a/tests/standard/aio/test_admin_client_index_get_status.py +++ b/tests/standard/aio/test_admin_client_index_get_status.py @@ -1,7 +1,6 @@ import pytest -from ...utils import random_name +from ...utils import DEFAULT_NAMESPACE -from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity from aerospike_vector_search import types, AVSServerError @@ -11,46 +10,26 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_status(session_admin_client, empty_test_case, random_name): - try: - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except Exception as e: - pass - +async def test_index_get_status(session_admin_client, empty_test_case, index): result : types.IndexStatusResponse = await session_admin_client.index_get_status( - namespace="test", name=random_name + namespace=DEFAULT_NAMESPACE, name=index ) assert result.unmerged_record_count == 0 - await drop_specified_index(session_admin_client, "test", random_name) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) async def test_index_get_status_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except Exception as e: - pass for i in range(10): try: result = await session_admin_client.index_get_status( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/aio/test_admin_client_index_list.py b/tests/standard/aio/test_admin_client_index_list.py index 9304320b..af633e8c 100644 --- a/tests/standard/aio/test_admin_client_index_list.py +++ b/tests/standard/aio/test_admin_client_index_list.py @@ -1,26 +1,14 @@ -import pytest - from aerospike_vector_search import AVSServerError import pytest import grpc - -from ...utils import random_name - -from .aio_utils import drop_specified_index from hypothesis import given, settings, Verbosity @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_list(session_admin_client, empty_test_case, random_name): - await session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +async def test_index_list(session_admin_client, empty_test_case, index): result = await session_admin_client.index_list(apply_defaults=True) assert len(result) > 0 for index in result: @@ -37,7 +25,6 @@ async def test_index_list(session_admin_client, empty_test_case, random_name): assert isinstance(index["hnsw_params"]["batching_params"]["reindex_interval"], int) assert isinstance(index["storage"]["namespace"], str) assert isinstance(index["storage"]["set_name"], str) - await drop_specified_index(session_admin_client, "test", random_name) async def test_index_list_timeout(session_admin_client, with_latency): diff --git a/tests/standard/aio/test_admin_client_index_update.py b/tests/standard/aio/test_admin_client_index_update.py index 6e6466d1..734b1b77 100644 --- a/tests/standard/aio/test_admin_client_index_update.py +++ b/tests/standard/aio/test_admin_client_index_update.py @@ -1,9 +1,9 @@ import time -import pytest -from aerospike_vector_search import types, AVSServerError -import grpc -from .aio_utils import drop_specified_index +from aerospike_vector_search import types +from utils import DEFAULT_NAMESPACE + +import pytest server_defaults = { "m": 16, @@ -20,7 +20,6 @@ class index_update_test_case: def __init__( self, *, - namespace, vector_field, dimensions, initial_labels, @@ -28,7 +27,6 @@ def __init__( hnsw_index_update, timeout ): - self.namespace = namespace self.vector_field = vector_field self.dimensions = dimensions self.initial_labels = initial_labels @@ -41,7 +39,6 @@ def __init__( "test_case", [ index_update_test_case( - namespace="test", vector_field="update_2", dimensions=256, initial_labels={"status": "active"}, @@ -63,31 +60,11 @@ def __init__( ), ], ) -async def test_index_update_async(session_admin_client, test_case): - # Create a unique index name for each test run - trimmed_random = "aBEd-1" - - # Drop any pre-existing index with the same name - try: - session_admin_client.index_drop(namespace="test", name=trimmed_random) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: - pass - - # Create the index - await session_admin_client.index_create( - namespace=test_case.namespace, - name=trimmed_random, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - index_labels=test_case.initial_labels, - timeout=test_case.timeout, - ) - +async def test_index_update_async(session_admin_client, test_case, index): # Update the index with new labels and parameters await session_admin_client.index_update( - namespace=test_case.namespace, - name=trimmed_random, + namespace=DEFAULT_NAMESPACE, + name=index, index_labels=test_case.update_labels, hnsw_update_params=test_case.hnsw_index_update ) @@ -96,10 +73,10 @@ async def test_index_update_async(session_admin_client, test_case): time.sleep(10) # Verify the update - result = await session_admin_client.index_get(namespace=test_case.namespace, name=trimmed_random, apply_defaults=True) + result = await session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." - assert result["id"]["namespace"] == test_case.namespace + assert result["id"]["namespace"] == DEFAULT_NAMESPACE # Assertions based on provided parameters if test_case.hnsw_index_update.batching_params: @@ -131,6 +108,3 @@ async def test_index_update_async(session_admin_client, test_case): "max_scan_rate_per_node"] == test_case.hnsw_index_update.healer_params.max_scan_rate_per_node assert result["hnsw_params"]["enable_vector_integrity_check"] == test_case.hnsw_index_update.enable_vector_integrity_check - - # Clean up by dropping the index after the test - await drop_specified_index(session_admin_client, test_case.namespace, trimmed_random) diff --git a/tests/standard/aio/test_service_config.py b/tests/standard/aio/test_service_config.py index 66277e57..bcc94eda 100644 --- a/tests/standard/aio/test_service_config.py +++ b/tests/standard/aio/test_service_config.py @@ -1,496 +1,496 @@ -import pytest -import time - -import os -import json - -from aerospike_vector_search import AVSServerError, types -from aerospike_vector_search.aio import AdminClient - - -class service_config_parse_test_case: - def __init__(self, *, service_config_path): - self.service_config_path = service_config_path - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_parse_test_case( - service_config_path="service_configs/master.json" - ), - ], -) -async def test_admin_client_service_config_parse( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - pass - - -class service_config_test_case: - def __init__( - self, *, service_config_path, namespace, name, vector_field, dimensions - ): - - script_dir = os.path.dirname(os.path.abspath(__file__)) - - self.service_config_path = os.path.abspath( - os.path.join(script_dir, "..", "..", service_config_path) - ) - - with open(self.service_config_path, "rb") as f: - self.service_config = json.load(f) - - self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ - "maxAttempts" - ] - self.initial_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] - ) - self.max_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] - ) - self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ - "backoffMultiplier" - ] - self.retryable_status_codes = self.service_config["methodConfig"][0][ - "retryPolicy" - ]["retryableStatusCodes"] - self.namespace = namespace - self.name = name - self.vector_field = vector_field - self.dimensions = dimensions - - -def calculate_expected_time( - max_attempts, - initial_backoff, - backoff_multiplier, - max_backoff, - retryable_status_codes, -): - - current_backkoff = initial_backoff - - expected_time = 0 - for attempt in range(max_attempts - 1): - expected_time += current_backkoff - current_backkoff *= backoff_multiplier - current_backkoff = min(current_backkoff, max_backoff) - - return expected_time - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retries.json", - namespace="test", - name="service_config_index_1", - vector_field="example_1", - dimensions=1024, - ) - ], -) -async def test_admin_client_service_config_retries( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - try: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/initial_backoff.json", - namespace="test", - name="service_config_index_2", - vector_field="example_1", - dimensions=1024, - ) - ], -) -async def test_admin_client_service_config_initial_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/max_backoff.json", - namespace="test", - name="service_config_index_3", - vector_field="example_1", - dimensions=1024, - ), - service_config_test_case( - service_config_path="service_configs/max_backoff_lower_than_initial.json", - namespace="test", - name="service_config_index_4", - vector_field="example_1", - dimensions=1024, - ), - ], -) -async def test_admin_client_service_config_max_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/backoff_multiplier.json", - namespace="test", - name="service_config_index_5", - vector_field="example_1", - dimensions=1024, - ) - ], -) -async def test_admin_client_service_config_backoff_multiplier( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retryable_status_codes.json", - namespace="test", - name="service_config_index_6", - vector_field=None, - dimensions=None, - ) - ], -) -async def test_admin_client_service_config_retryable_status_codes( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - async with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - await client.index_get_status( - namespace=test_case.namespace, - name=test_case.name, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 2 +# import pytest +# import time + +# import os +# import json + +# from aerospike_vector_search import AVSServerError, types +# from aerospike_vector_search.aio import AdminClient + + +# class service_config_parse_test_case: +# def __init__(self, *, service_config_path): +# self.service_config_path = service_config_path + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_parse_test_case( +# service_config_path="service_configs/master.json" +# ), +# ], +# ) +# async def test_admin_client_service_config_parse( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: +# pass + + +# class service_config_test_case: +# def __init__( +# self, *, service_config_path, namespace, name, vector_field, dimensions +# ): + +# script_dir = os.path.dirname(os.path.abspath(__file__)) + +# self.service_config_path = os.path.abspath( +# os.path.join(script_dir, "..", "..", service_config_path) +# ) + +# with open(self.service_config_path, "rb") as f: +# self.service_config = json.load(f) + +# self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ +# "maxAttempts" +# ] +# self.initial_backoff = int( +# self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] +# ) +# self.max_backoff = int( +# self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] +# ) +# self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ +# "backoffMultiplier" +# ] +# self.retryable_status_codes = self.service_config["methodConfig"][0][ +# "retryPolicy" +# ]["retryableStatusCodes"] +# self.namespace = namespace +# self.name = name +# self.vector_field = vector_field +# self.dimensions = dimensions + + +# def calculate_expected_time( +# max_attempts, +# initial_backoff, +# backoff_multiplier, +# max_backoff, +# retryable_status_codes, +# ): + +# current_backkoff = initial_backoff + +# expected_time = 0 +# for attempt in range(max_attempts - 1): +# expected_time += current_backkoff +# current_backkoff *= backoff_multiplier +# current_backkoff = min(current_backkoff, max_backoff) + +# return expected_time + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retries.json", +# namespace="test", +# name="service_config_index_1", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# async def test_admin_client_service_config_retries( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: +# try: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/initial_backoff.json", +# namespace="test", +# name="service_config_index_2", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# async def test_admin_client_service_config_initial_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/max_backoff.json", +# namespace="test", +# name="service_config_index_3", +# vector_field="example_1", +# dimensions=1024, +# ), +# service_config_test_case( +# service_config_path="service_configs/max_backoff_lower_than_initial.json", +# namespace="test", +# name="service_config_index_4", +# vector_field="example_1", +# dimensions=1024, +# ), +# ], +# ) +# async def test_admin_client_service_config_max_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/backoff_multiplier.json", +# namespace="test", +# name="service_config_index_5", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# async def test_admin_client_service_config_backoff_multiplier( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: + +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retryable_status_codes.json", +# namespace="test", +# name="service_config_index_6", +# vector_field=None, +# dimensions=None, +# ) +# ], +# ) +# async def test_admin_client_service_config_retryable_status_codes( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# async with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# await client.index_get_status( +# namespace=test_case.namespace, +# name=test_case.name, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 2 diff --git a/tests/standard/aio/test_vector_client_delete.py b/tests/standard/aio/test_vector_client_delete.py index 8daf87e6..f5479095 100644 --- a/tests/standard/aio/test_vector_client_delete.py +++ b/tests/standard/aio/test_vector_client_delete.py @@ -1,6 +1,6 @@ import pytest from aerospike_vector_search import AVSServerError -from ...utils import random_key +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import grpc @@ -11,13 +11,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = record_data self.timeout = timeout @@ -27,33 +25,27 @@ def __init__( "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -async def test_vector_delete(session_vector_client, test_case, random_key): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +async def test_vector_delete(session_vector_client, test_case, record): await session_vector_client.delete( namespace=test_case.namespace, - key=random_key, + key=record, + set_name=test_case.set_name, + timeout=test_case.timeout, ) with pytest.raises(AVSServerError) as e_info: result = await session_vector_client.get( - namespace=test_case.namespace, key=random_key + namespace=test_case.namespace, key=record ) @@ -63,9 +55,8 @@ async def test_vector_delete(session_vector_client, test_case, random_key): "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), ], @@ -86,15 +77,14 @@ async def test_vector_delete_without_record( [ None, delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=0.0001, ), ], ) async def test_vector_delete_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -102,6 +92,6 @@ async def test_vector_delete_timeout( for i in range(10): await session_vector_client.delete( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_exists.py b/tests/standard/aio/test_vector_client_exists.py index 699110fa..42cdf5c9 100644 --- a/tests/standard/aio/test_vector_client_exists.py +++ b/tests/standard/aio/test_vector_client_exists.py @@ -1,6 +1,8 @@ import pytest import grpc -from ...utils import random_key + +from aerospike_vector_search import AVSServerError +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -10,13 +12,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = record_data self.timeout = timeout @@ -26,37 +26,24 @@ def __init__( "test_case", [ exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -async def test_vector_exists(session_vector_client, test_case, random_key): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +async def test_vector_exists(session_vector_client, test_case, record): result = await session_vector_client.exists( namespace=test_case.namespace, - key=random_key, + key=record, ) assert result is True - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @@ -64,18 +51,18 @@ async def test_vector_exists(session_vector_client, test_case, random_key): "test_case", [ exists_test_case( - namespace="test", set_name=None, record_data=None, timeout=0.0001 + namespace="test", set_name=None, timeout=0.0001 ), ], ) async def test_vector_exists_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") with pytest.raises(AVSServerError) as e_info: for i in range(10): result = await session_vector_client.exists( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_get.py b/tests/standard/aio/test_vector_client_get.py index b354d773..305ca760 100644 --- a/tests/standard/aio/test_vector_client_get.py +++ b/tests/standard/aio/test_vector_client_get.py @@ -1,8 +1,8 @@ import pytest -from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_key +from aerospike_vector_search import AVSServerError +from utils import DEFAULT_NAMESPACE, random_key from hypothesis import given, settings, Verbosity @@ -15,7 +15,6 @@ def __init__( include_fields, exclude_fields, set_name, - record_data, expected_fields, timeout, ): @@ -23,7 +22,6 @@ def __init__( self.include_fields = include_fields self.exclude_fields = exclude_fields self.set_name = set_name - self.record_data = record_data self.expected_fields = expected_fields self.timeout = timeout @@ -31,83 +29,92 @@ def __init__( #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( - "test_case", + "record,test_case", [ - get_test_case( - namespace="test", - include_fields=["skills"], - exclude_fields = None, - set_name=None, - record_data={"skills": [i for i in range(1024)]}, - expected_fields={"skills": [i for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"skills": 1024}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["skills"], + exclude_fields = None, + set_name=None, + expected_fields={"skills": 1024}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - exclude_fields = None, - set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, - expected_fields={"english": [float(i) for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": [float(i) for i in range(1024)]}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields = None, + set_name=None, + expected_fields={"english": [float(i) for i in range(1024)]}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - exclude_fields = None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields = None, + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=["spanish"], + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["spanish"], - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["spanish"], + exclude_fields=["spanish"], + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=[], - exclude_fields=None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=[], + exclude_fields=None, + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=[], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1, "spanish": 2}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=[], + set_name=None, + expected_fields={"english": 1, "spanish": 2}, + timeout=None, + ), ), ], + indirect=["record"], ) -async def test_vector_get(session_vector_client, test_case, random_key): - await session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +async def test_vector_get(session_vector_client, test_case, record): result = await session_vector_client.get( namespace=test_case.namespace, - key=random_key, + key=record, include_fields=test_case.include_fields, exclude_fields=test_case.exclude_fields, ) @@ -115,15 +122,10 @@ async def test_vector_get(session_vector_client, test_case, random_key): if test_case.set_name == None: test_case.set_name = "" assert result.key.set == test_case.set_name - assert result.key.key == random_key + assert result.key.key == record assert result.fields == test_case.expected_fields - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @@ -135,7 +137,6 @@ async def test_vector_get(session_vector_client, test_case, random_key): include_fields=["skills"], exclude_fields = None, set_name=None, - record_data=None, expected_fields=None, timeout=0.0001, ), diff --git a/tests/standard/aio/test_vector_client_insert.py b/tests/standard/aio/test_vector_client_insert.py index ec503c5b..04e587cf 100644 --- a/tests/standard/aio/test_vector_client_insert.py +++ b/tests/standard/aio/test_vector_client_insert.py @@ -1,6 +1,6 @@ import pytest from aerospike_vector_search import AVSServerError -from ...utils import random_key +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import asyncio @@ -30,19 +30,19 @@ def __init__( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"homeSkills": [float(i) for i in range(1024)]}, set_name=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, timeout=None, @@ -62,6 +62,10 @@ async def test_vector_insert_without_existing_record( record_data=test_case.record_data, set_name=test_case.set_name, ) + await session_vector_client.delete( + namespace=test_case.namespace, + key=random_key, + ) #@given(random_key=key_strategy()) @@ -70,7 +74,7 @@ async def test_vector_insert_without_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -78,29 +82,15 @@ async def test_vector_insert_without_existing_record( ], ) async def test_vector_insert_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): - try: - await session_vector_client.insert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - except Exception as e: - pass - with pytest.raises(AVSServerError) as e_info: await session_vector_client.insert( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, ) - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) #@given(random_key=key_strategy()) @@ -109,7 +99,7 @@ async def test_vector_insert_with_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=0.0001, @@ -130,4 +120,8 @@ async def test_vector_insert_timeout( set_name=test_case.set_name, timeout=test_case.timeout, ) + await session_vector_client.delete( + namespace=test_case.namespace, + key=random_key, + ) assert e_info.value.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/aio/test_vector_client_is_indexed.py b/tests/standard/aio/test_vector_client_is_indexed.py index d1ae8519..f222d6b3 100644 --- a/tests/standard/aio/test_vector_client_is_indexed.py +++ b/tests/standard/aio/test_vector_client_is_indexed.py @@ -1,81 +1,45 @@ import pytest -import random +import time + +from utils import DEFAULT_NAMESPACE +from .aio_utils import wait_for_index from aerospike_vector_search import AVSServerError -from aerospike_vector_search.aio import Client, AdminClient -import asyncio import grpc -# Define module-level constants for common arguments -NAMESPACE = "test" -SET_NAME = "isidxset" -INDEX_NAME = "isidx" -DIMENSIONS = 3 -VECTOR_FIELD = "vector" - -@pytest.fixture(scope="module", autouse=True) -async def setup_teardown(session_vector_client: Client, session_admin_client: AdminClient): - # Setup: Create index and upsert records - await session_admin_client.index_create( - namespace=NAMESPACE, - sets=SET_NAME, - name=INDEX_NAME, - dimensions=DIMENSIONS, - vector_field=VECTOR_FIELD, - ) - tasks = [] - for i in range(10): - tasks.append(session_vector_client.upsert( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i), - record_data={VECTOR_FIELD: [float(i)] * DIMENSIONS}, - )) - tasks.append(session_vector_client.wait_for_index_completion( - namespace=NAMESPACE, - name=INDEX_NAME, - )) - await asyncio.gather(*tasks) - yield - # Teardown: remove records - tasks = [] - for i in range(10): - tasks.append(session_vector_client.delete( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i) - )) - await asyncio.gather(*tasks) - - # Teardown: Drop index - await session_admin_client.index_drop( - namespace=NAMESPACE, - name=INDEX_NAME - ) async def test_vector_is_indexed( + session_admin_client, session_vector_client, + index, + record ): + # wait for the record to be indexed + await wait_for_index( + admin_client=session_admin_client, + namespace=DEFAULT_NAMESPACE, + index=index + ) + result = await session_vector_client.is_indexed( - namespace=NAMESPACE, - key="0", - index_name=INDEX_NAME, - set_name=SET_NAME, + namespace=DEFAULT_NAMESPACE, + key=record, + index_name=index, ) assert result is True async def test_vector_is_indexed_timeout( - session_vector_client, with_latency + session_vector_client, with_latency, index, record ): if not with_latency: pytest.skip("Server latency too low to test timeout") for _ in range(10): try: await session_vector_client.is_indexed( - namespace=NAMESPACE, - key="0", - index_name=INDEX_NAME, + namespace=DEFAULT_NAMESPACE, + key=record, + index_name=index, timeout=0.0001, ) except AVSServerError as se: diff --git a/tests/standard/aio/test_vector_client_search_by_key.py b/tests/standard/aio/test_vector_client_search_by_key.py index faabbec5..9434e1d1 100644 --- a/tests/standard/aio/test_vector_client_search_by_key.py +++ b/tests/standard/aio/test_vector_client_search_by_key.py @@ -1,14 +1,21 @@ import numpy as np -import asyncio import pytest + +from utils import DEFAULT_NAMESPACE +from .aio_utils import wait_for_index from aerospike_vector_search import types +INDEX = "sbk_index" +NAMESPACE = DEFAULT_NAMESPACE +DIMENSIONS = 3 +VEC_BIN = "vector" +SET_NAME = "test_set" + class vector_search_by_key_test_case: def __init__( self, *, - index_name, index_dimensions, vector_field, limit, @@ -18,10 +25,8 @@ def __init__( include_fields, exclude_fields, key_set, - record_data, expected_results, ): - self.index_name = index_name self.index_dimensions = index_dimensions self.vector_field = vector_field self.limit = limit @@ -30,45 +35,120 @@ def __init__( self.include_fields = include_fields self.exclude_fields = exclude_fields self.key_set = key_set - self.record_data = record_data self.expected_results = expected_results self.key_namespace = key_namespace -# TODO add a teardown + +@pytest.fixture(scope="module") +async def setup_index( + session_admin_client, +): + await session_admin_client.index_create( + namespace=DEFAULT_NAMESPACE, + name=INDEX, + vector_field=VEC_BIN, + dimensions=DIMENSIONS, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) + ) + + yield + + await session_admin_client.index_drop( + namespace=DEFAULT_NAMESPACE, + name=INDEX, + ) + + +@pytest.fixture(scope="module", autouse=True) +async def setup_records( + session_vector_client, +): + recs = { + "rec1": { + "bin": 1, + VEC_BIN: [1.0] * DIMENSIONS, + }, + 2: { + "bin": 2, + VEC_BIN: [2.0] * DIMENSIONS, + }, + bytes("rec5", "utf-8"): { + "bin": 5, + VEC_BIN: [5.0] * DIMENSIONS, + }, + } + + keys = [] + for key, record in recs.items(): + keys.append(key) + await session_vector_client.upsert( + namespace=DEFAULT_NAMESPACE, + key=key, + record_data=record, + ) + + # write some records for set tests + set_recs = { + "srec100": { + "bin": 100, + VEC_BIN: [100.0] * DIMENSIONS, + }, + "srec101": { + "bin": 101, + VEC_BIN: [101.0] * DIMENSIONS, + }, + } + + for key, record in set_recs.items(): + keys.append(key) + await session_vector_client.upsert( + namespace=DEFAULT_NAMESPACE, + key=key, + record_data=record, + set_name=SET_NAME, + ) + + yield + + for key in keys: + await session_vector_client.delete( + namespace=DEFAULT_NAMESPACE, + key=key, + ) + + + #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( "test_case", [ # test string key vector_search_by_key_test_case( - index_name="basic_search", index_dimensions=3, vector_field="vector", limit=2, key="rec1", - key_namespace="test", - search_namespace="test", + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=None, exclude_fields=None, key_set=None, - record_data={ - "rec1": { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - "rec2": { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - "rec3": { - "bin": 3, - "vector": [3.0, 3.0, 3.0], - }, - }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", key="rec1", ), @@ -80,9 +160,9 @@ def __init__( ), types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key="rec2", + key=2, ), fields={ "bin": 2, @@ -94,237 +174,172 @@ def __init__( ), # test int key vector_search_by_key_test_case( - index_name="field_filter", index_dimensions=3, vector_field="vector", limit=3, - key=1, - key_namespace="test", - search_namespace="test", + key=2, + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=["bin"], exclude_fields=["bin"], key_set=None, - record_data={ - 1: { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - 2: { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=1, + key=2, ), fields={}, distance=0.0, ), types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=2, + key="rec1", ), fields={}, distance=3.0, ), + types.Neighbor( + key=types.Key( + namespace=DEFAULT_NAMESPACE, + set="", + key=bytes("rec5", "utf-8"), + ), + fields={}, + distance=27.0, + ), ], ), # test bytes key vector_search_by_key_test_case( - index_name="field_filter", index_dimensions=3, vector_field="vector", limit=3, - key=bytes("rec1", "utf-8"), - key_namespace="test", - search_namespace="test", + key=bytes("rec5", "utf-8"), + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=["bin"], exclude_fields=["bin"], key_set=None, - record_data={ - bytes("rec1", "utf-8"): { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - bytes("rec2", "utf-8"): { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=bytes("rec1", "utf-8"), + key=bytes("rec5", "utf-8"), ), fields={}, distance=0.0, ), types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", - key=bytes("rec2", "utf-8"), + key=2, ), fields={}, - distance=3.0, + distance=27.0, + ), + types.Neighbor( + key=types.Key( + namespace=DEFAULT_NAMESPACE, + set="", + key="rec1", + ), + fields={}, + distance=48.0, ), ], ), - # test bytearray key - # TODO: add a bytearray key case, bytearrays are not hashable - # so this is not easily added. Leaving it for now. - # vector_search_by_key_test_case( - # index_name="field_filter", - # index_dimensions=3, - # vector_field="vector", - # limit=3, - # key=bytearray("rec1", "utf-8"), - # namespace="test", - # include_fields=["bin"], - # exclude_fields=["bin"], - # key_set=None, - # record_data={ - # bytearray("rec1", "utf-8"): { - # "bin": 1, - # "vector": [1.0, 1.0, 1.0], - # }, - # bytearray("rec1", "utf-8"): { - # "bin": 2, - # "vector": [2.0, 2.0, 2.0], - # }, - # }, - # expected_results=[ - # types.Neighbor( - # key=types.Key( - # namespace="test", - # set="", - # key=2, - # ), - # fields={}, - # distance=3.0, - # ), - # ], - # ), + # # test bytearray key + # # TODO: add a bytearray key case, bytearrays are not hashable + # # so this is not easily added. Leaving it for now. + # # vector_search_by_key_test_case( + # # index_name="field_filter", + # # index_dimensions=3, + # # vector_field="vector", + # # limit=3, + # # key=bytearray("rec1", "utf-8"), + # # namespace=DEFAULT_NAMESPACE, + # # include_fields=["bin"], + # # exclude_fields=["bin"], + # # key_set=None, + # # record_data={ + # # bytearray("rec1", "utf-8"): { + # # "bin": 1, + # # "vector": [1.0, 1.0, 1.0], + # # }, + # # bytearray("rec1", "utf-8"): { + # # "bin": 2, + # # "vector": [2.0, 2.0, 2.0], + # # }, + # # }, + # # expected_results=[ + # # types.Neighbor( + # # key=types.Key( + # # namespace=DEFAULT_NAMESPACE, + # # set="", + # # key=2, + # # ), + # # fields={}, + # # distance=3.0, + # # ), + # # ], + # # ), # test with set name vector_search_by_key_test_case( - index_name="basic_search", index_dimensions=3, vector_field="vector", limit=2, - key="rec1", - key_namespace="test", - search_namespace="test", + key="srec100", + key_namespace=DEFAULT_NAMESPACE, + search_namespace=DEFAULT_NAMESPACE, include_fields=None, exclude_fields=None, - key_set="test_set", - record_data={ - "rec1": { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - "rec2": { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - }, + key_set=SET_NAME, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", - set="test_set", - key="rec1", + namespace=DEFAULT_NAMESPACE, + set=SET_NAME, + key="srec100", ), fields={ - "bin": 1, - "vector": [1.0, 1.0, 1.0], + "bin": 100, + "vector": [100.0] * DIMENSIONS, }, distance=0.0, ), types.Neighbor( key=types.Key( - namespace="test", - set="test_set", - key="rec2", + namespace=DEFAULT_NAMESPACE, + set=SET_NAME, + key="srec101", ), fields={ - "bin": 2, - "vector": [2.0, 2.0, 2.0], + "bin": 101, + "vector": [101.0] * DIMENSIONS, }, distance=3.0, ), ], ), - # test search key record and search records are in different namespaces - vector_search_by_key_test_case( - index_name="basic_search", - index_dimensions=3, - vector_field="vector", - limit=2, - key="rec1", - key_namespace="test", - search_namespace="index_storage", - include_fields=None, - exclude_fields=None, - key_set=None, - record_data={ - "rec1": { - "bin": 1, - "vector": [1.0, 1.0, 1.0], - }, - "rec2": { - "bin": 2, - "vector": [2.0, 2.0, 2.0], - }, - "rec3": { - "bin": 3, - "vector": [3.0, 3.0, 3.0], - }, - }, - expected_results=[], - ), ], ) async def test_vector_search_by_key( session_vector_client, session_admin_client, + setup_index, test_case, ): - - await session_admin_client.index_create( - namespace=test_case.search_namespace, - name=test_case.index_name, - vector_field=test_case.vector_field, - dimensions=test_case.index_dimensions, - ) - - tasks = [] - for key, rec in test_case.record_data.items(): - tasks.append(session_vector_client.upsert( - namespace=test_case.key_namespace, - key=key, - record_data=rec, - set_name=test_case.key_set, - )) - - tasks.append( - session_vector_client.wait_for_index_completion( - namespace=test_case.search_namespace, - name=test_case.index_name, - ) - ) - await asyncio.gather(*tasks) + await wait_for_index(session_admin_client, DEFAULT_NAMESPACE, INDEX) results = await session_vector_client.vector_search_by_key( search_namespace=test_case.search_namespace, - index_name=test_case.index_name, + index_name=INDEX, key=test_case.key, key_namespace=test_case.key_namespace, vector_field=test_case.vector_field, @@ -336,21 +351,6 @@ async def test_vector_search_by_key( assert results == test_case.expected_results - tasks = [] - for key in test_case.record_data: - tasks.append(session_vector_client.delete( - namespace=test_case.key_namespace, - set_name=test_case.key_set, - key=key, - )) - - await asyncio.gather(*tasks) - - await session_admin_client.index_drop( - namespace=test_case.search_namespace, - name=test_case.index_name, - ) - async def test_vector_search_by_key_different_namespaces( session_vector_client, @@ -362,6 +362,19 @@ async def test_vector_search_by_key_different_namespaces( name="diff_ns_idx", vector_field="vec", dimensions=3, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) await session_vector_client.upsert( @@ -382,10 +395,7 @@ async def test_vector_search_by_key_different_namespaces( }, ) - await session_vector_client.wait_for_index_completion( - namespace="index_storage", - name="diff_ns_idx", - ) + await wait_for_index(session_admin_client, "index_storage", "diff_ns_idx") results = await session_vector_client.vector_search_by_key( search_namespace="index_storage", diff --git a/tests/standard/aio/test_vector_client_update.py b/tests/standard/aio/test_vector_client_update.py index 6c31dcb1..fc5b5264 100644 --- a/tests/standard/aio/test_vector_client_update.py +++ b/tests/standard/aio/test_vector_client_update.py @@ -1,6 +1,6 @@ import pytest from aerospike_vector_search import AVSServerError -from ...utils import random_key +from ...utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import grpc @@ -27,19 +27,19 @@ def __init__( "test_case", [ update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, ), update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [float(i) for i in range(1024)]}, set_name=None, timeout=None, ), update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, timeout=None, @@ -47,27 +47,14 @@ def __init__( ], ) async def test_vector_update_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): - try: - await session_vector_client.insert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) - except Exception as e: - pass await session_vector_client.update( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, ) - await session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) #@given(random_key=key_strategy()) @@ -76,7 +63,7 @@ async def test_vector_update_with_existing_record( "test_case", [ update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -104,9 +91,8 @@ async def test_vector_update_without_existing_record( @pytest.mark.parametrize( "test_case", [ - None, update_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=0.0001, diff --git a/tests/standard/aio/test_vector_client_upsert.py b/tests/standard/aio/test_vector_client_upsert.py index 5cacc11d..cac115b4 100644 --- a/tests/standard/aio/test_vector_client_upsert.py +++ b/tests/standard/aio/test_vector_client_upsert.py @@ -1,5 +1,5 @@ import pytest -from ...utils import random_key +from ...utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity import numpy as np @@ -24,19 +24,19 @@ def __init__(self, *, namespace, record_data, set_name, timeout, key=None): "test_case", [ upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, ), upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [float(i) for i in range(1024)]}, set_name=None, timeout=None, ), upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, timeout=None, @@ -65,7 +65,7 @@ async def test_vector_upsert_without_existing_record( "test_case", [ upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -92,14 +92,14 @@ async def test_vector_upsert_with_existing_record( "test_case", [ upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, key=np.int32(31), ), upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=None, @@ -128,7 +128,7 @@ async def test_vector_upsert_with_numpy_key(session_vector_client, test_case): [ None, upsert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, timeout=0.0001, diff --git a/tests/standard/aio/test_vector_search.py b/tests/standard/aio/test_vector_search.py index 88cab272..ff597fe0 100644 --- a/tests/standard/aio/test_vector_search.py +++ b/tests/standard/aio/test_vector_search.py @@ -1,7 +1,10 @@ import numpy as np import asyncio import pytest + from aerospike_vector_search import types +from utils import DEFAULT_NAMESPACE +from .aio_utils import wait_for_index class vector_search_test_case: @@ -32,6 +35,7 @@ def __init__( self.record_data = record_data self.expected_results = expected_results + # TODO add a teardown #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( @@ -40,29 +44,29 @@ def __init__( vector_search_test_case( index_name="basic_search", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], - namespace="test", + namespace=DEFAULT_NAMESPACE, include_fields=None, exclude_fields = None, set_name=None, record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", key="rec1", ), fields={ "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, distance=3.0, ), @@ -71,23 +75,23 @@ def __init__( vector_search_test_case( index_name="field_filter", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], - namespace="test", + namespace=DEFAULT_NAMESPACE, include_fields=["bin1"], exclude_fields=["bin1"], set_name=None, record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ types.Neighbor( key=types.Key( - namespace="test", + namespace=DEFAULT_NAMESPACE, set="", key="rec1", ), @@ -109,6 +113,19 @@ async def test_vector_search( name=test_case.index_name, vector_field=test_case.vector_field, dimensions=test_case.index_dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) tasks = [] @@ -121,9 +138,10 @@ async def test_vector_search( )) tasks.append( - session_vector_client.wait_for_index_completion( + wait_for_index( + session_admin_client, namespace=test_case.namespace, - name=test_case.index_name, + index=test_case.index_name, ) ) await asyncio.gather(*tasks) diff --git a/tests/standard/sync/conftest.py b/tests/standard/sync/conftest.py index 6e7a229b..d6ffff5b 100644 --- a/tests/standard/sync/conftest.py +++ b/tests/standard/sync/conftest.py @@ -4,9 +4,10 @@ from aerospike_vector_search import Client from aerospike_vector_search.admin import Client as AdminClient -from aerospike_vector_search import types +from aerospike_vector_search import types, AVSServerError from .sync_utils import gen_records +import grpc #import logging #logger = logging.getLogger(__name__) @@ -195,14 +196,37 @@ def index_name(): @pytest.fixture(params=[DEFAULT_INDEX_ARGS]) def index(session_admin_client, index_name, request): - index_args = request.param + args = request.param + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) session_admin_client.index_create( name = index_name, - **index_args, + namespace = namespace, + vector_field = vector_field, + dimensions = dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) yield index_name - namespace = index_args.get("namespace", DEFAULT_NAMESPACE) - session_admin_client.index_drop(namespace=namespace, name=index_name) + try: + session_admin_client.index_drop(namespace=namespace, name=index_name) + except AVSServerError as se: + if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: + pass + else: + raise @pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) @@ -213,10 +237,35 @@ def records(session_vector_client, request): num_records = args.get("num_records", DEFAULT_NUM_RECORDS) vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) keys = [] for key, rec in record_generator(count=num_records, vec_bin=vector_field, vec_dim=dimensions): - session_vector_client.upsert(namespace=namespace, key=key, record_data=rec) + session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) keys.append(key) - yield len(keys) + yield keys for key in keys: - session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file + session_vector_client.delete(key=key, namespace=namespace) + + +@pytest.fixture(params=[DEFAULT_RECORDS_ARGS]) +def record(session_vector_client, request): + args = request.param + record_generator = args.get("record_generator", DEFAULT_RECORD_GENERATOR) + namespace = args.get("namespace", DEFAULT_NAMESPACE) + vector_field = args.get("vector_field", DEFAULT_VECTOR_FIELD) + dimensions = args.get("dimensions", DEFAULT_INDEX_DIMENSION) + set_name = args.get("set_name", None) + key, rec = next(record_generator(count=1, vec_bin=vector_field, vec_dim=dimensions)) + session_vector_client.upsert( + namespace=namespace, + key=key, + record_data=rec, + set_name=set_name, + ) + yield key + session_vector_client.delete(key=key, namespace=namespace) \ No newline at end of file diff --git a/tests/standard/sync/sync_utils.py b/tests/standard/sync/sync_utils.py index b3936c8a..05aeeb94 100644 --- a/tests/standard/sync/sync_utils.py +++ b/tests/standard/sync/sync_utils.py @@ -1,6 +1,9 @@ +import time + def drop_specified_index(admin_client, namespace, name): admin_client.index_drop(namespace=namespace, name=name) + def gen_records(count: int, vec_bin: str, vec_dim: int): num = 0 while num < count: @@ -10,3 +13,22 @@ def gen_records(count: int, vec_bin: str, vec_dim: int): ) yield key_and_rec num += 1 + + +def wait_for_index(admin_client, namespace: str, index: str): + + verticies = 0 + unmerged_recs = 0 + + while verticies == 0 or unmerged_recs > 0: + status = admin_client.index_get_status( + namespace=namespace, + name=index, + ) + + verticies = status.index_healer_vertices_valid + unmerged_recs = status.unmerged_record_count + + # print(verticies) + # print(unmerged_recs) + time.sleep(0.5) \ No newline at end of file diff --git a/tests/standard/sync/test_admin_client_index_create.py b/tests/standard/sync/test_admin_client_index_create.py index dd23c75f..3b90fdba 100644 --- a/tests/standard/sync/test_admin_client_index_create.py +++ b/tests/standard/sync/test_admin_client_index_create.py @@ -2,9 +2,9 @@ import grpc from aerospike_vector_search import types, AVSServerError -from ...utils import random_name - +from ...utils import random_name, DEFAULT_NAMESPACE from .sync_utils import drop_specified_index + from hypothesis import given, settings, Verbosity server_defaults = { @@ -51,7 +51,7 @@ def __init__( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_1", dimensions=1024, vector_distance_metric=None, @@ -65,7 +65,7 @@ def __init__( ) def test_index_create(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -110,7 +110,7 @@ def test_index_create(session_admin_client, test_case, random_name): "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_2", dimensions=495, vector_distance_metric=None, @@ -121,7 +121,7 @@ def test_index_create(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_3", dimensions=2048, vector_distance_metric=None, @@ -136,7 +136,7 @@ def test_index_create(session_admin_client, test_case, random_name): def test_index_create_with_dimnesions(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -184,7 +184,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_4", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.COSINE, @@ -195,7 +195,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_5", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.DOT_PRODUCT, @@ -206,7 +206,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_6", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.MANHATTAN, @@ -217,7 +217,7 @@ def test_index_create_with_dimnesions(session_admin_client, test_case, random_na timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_7", dimensions=1024, vector_distance_metric=types.VectorDistanceMetric.HAMMING, @@ -234,7 +234,7 @@ def test_index_create_with_vector_distance_metric( ): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -278,7 +278,7 @@ def test_index_create_with_vector_distance_metric( "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_8", dimensions=1024, vector_distance_metric=None, @@ -289,7 +289,7 @@ def test_index_create_with_vector_distance_metric( timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_9", dimensions=1024, vector_distance_metric=None, @@ -304,7 +304,7 @@ def test_index_create_with_vector_distance_metric( def test_index_create_with_sets(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -348,7 +348,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_10", dimensions=1024, vector_distance_metric=None, @@ -364,7 +364,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_11", dimensions=1024, vector_distance_metric=None, @@ -377,7 +377,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_20", dimensions=1024, vector_distance_metric=None, @@ -390,7 +390,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_12", dimensions=1024, vector_distance_metric=None, @@ -404,7 +404,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): timeout=None, ), index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_13", dimensions=1024, vector_distance_metric=None, @@ -432,7 +432,7 @@ def test_index_create_with_sets(session_admin_client, test_case, random_name): ) def test_index_create_with_index_params(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -544,7 +544,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_16", dimensions=1024, vector_distance_metric=None, @@ -558,7 +558,7 @@ def test_index_create_with_index_params(session_admin_client, test_case, random_ ) def test_index_create_index_labels(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -605,21 +605,21 @@ def test_index_create_index_labels(session_admin_client, test_case, random_name) "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_17", dimensions=1024, vector_distance_metric=None, sets=None, index_params=None, index_labels=None, - index_storage=types.IndexStorage(namespace="test", set_name="foo"), + index_storage=types.IndexStorage(namespace=DEFAULT_NAMESPACE, set_name="foo"), timeout=None, ), ], ) def test_index_create_index_storage(session_admin_client, test_case, random_name): try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass @@ -662,7 +662,7 @@ def test_index_create_index_storage(session_admin_client, test_case, random_name "test_case", [ index_create_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, vector_field="example_18", dimensions=1024, vector_distance_metric=None, @@ -681,7 +681,7 @@ def test_index_create_timeout( if not with_latency: pytest.skip("Server latency too low to test timeout") try: - session_admin_client.index_drop(namespace="test", name=random_name) + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=random_name) except AVSServerError as se: if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: pass diff --git a/tests/standard/sync/test_admin_client_index_drop.py b/tests/standard/sync/test_admin_client_index_drop.py index 0523510a..cf05a88f 100644 --- a/tests/standard/sync/test_admin_client_index_drop.py +++ b/tests/standard/sync/test_admin_client_index_drop.py @@ -3,8 +3,7 @@ from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_name - +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -12,52 +11,31 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_drop(session_admin_client, empty_test_case, random_name): - try: - - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - dimensions=1024, - ) - - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - - session_admin_client.index_drop(namespace="test", name=random_name) +def test_index_drop(session_admin_client, empty_test_case, index): + session_admin_client.index_drop(namespace=DEFAULT_NAMESPACE, name=index) result = session_admin_client.index_list() result = result for index in result: - assert index["id"]["name"] != random_name + assert index["id"]["name"] != index @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_drop_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, + empty_test_case, + index, + with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="art", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - for i in range(10): try: session_admin_client.index_drop( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_admin_client_index_get.py b/tests/standard/sync/test_admin_client_index_get.py index 1831bcab..b8034a71 100644 --- a/tests/standard/sync/test_admin_client_index_get.py +++ b/tests/standard/sync/test_admin_client_index_get.py @@ -1,44 +1,29 @@ -import pytest -from ...utils import random_name - -from .sync_utils import drop_specified_index -from hypothesis import given, settings, Verbosity - +from ...utils import DEFAULT_NAMESPACE, DEFAULT_INDEX_DIMENSION, DEFAULT_VECTOR_FIELD from aerospike_vector_search import AVSServerError -import grpc +import grpc +from hypothesis import given, settings, Verbosity +import pytest @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get(session_admin_client, empty_test_case, random_name): - - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - - result = session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=True) - - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" +def test_index_get(session_admin_client, empty_test_case, index): + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) + + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD assert result["hnsw_params"]["m"] == 16 assert result["hnsw_params"]["ef_construction"] == 100 assert result["hnsw_params"]["ef"] == 100 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 100000 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 30000 + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == max(100000 / 10, 1000) assert result["hnsw_params"]["batching_params"]["reindex_interval"] == 30000 - assert result["storage"]["namespace"] == "test" - assert result["storage"]["set_name"] == random_name + assert result["storage"]["namespace"] == DEFAULT_NAMESPACE + assert result["storage"]["set_name"] == index # Defaults assert result["sets"] == "" @@ -51,33 +36,25 @@ def test_index_get(session_admin_client, empty_test_case, random_name): assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 1000 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 10000 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 10.0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "0 0/15 * ? * * *" + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" assert result["hnsw_params"]["healer_params"]["parallelism"] == 1 # index parallelism and reindex parallelism are dynamic depending on the CPU cores of the host # assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 80 # assert result["hnsw_params"]["merge_params"]["reindex_parallelism"] == 26 - drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -async def test_index_get_no_defaults(session_admin_client, empty_test_case, random_name): - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +async def test_index_get_no_defaults(session_admin_client, empty_test_case, index): - result = session_admin_client.index_get(namespace="test", name=random_name, apply_defaults=False) + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=False) - assert result["id"]["name"] == random_name - assert result["id"]["namespace"] == "test" - assert result["dimensions"] == 1024 - assert result["field"] == "science" + assert result["id"]["name"] == index + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + assert result["dimensions"] == DEFAULT_INDEX_DIMENSION + assert result["field"] == DEFAULT_VECTOR_FIELD # Defaults assert result["sets"] == "" @@ -87,7 +64,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["hnsw_params"]["ef"] == 0 assert result["hnsw_params"]["ef_construction"] == 0 assert result["hnsw_params"]["batching_params"]["max_index_records"] == 0 - assert result["hnsw_params"]["batching_params"]["index_interval"] == 0 + # This is set by default to 10000 in the index fixture + assert result["hnsw_params"]["batching_params"]["index_interval"] == 10000 assert result["hnsw_params"]["max_mem_queue_size"] == 0 assert result["hnsw_params"]["index_caching_params"]["max_entries"] == 0 assert result["hnsw_params"]["index_caching_params"]["expiry"] == 0 @@ -95,7 +73,8 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == 0 assert result["hnsw_params"]["healer_params"]["max_scan_page_size"] == 0 assert result["hnsw_params"]["healer_params"]["re_index_percent"] == 0 - assert result["hnsw_params"]["healer_params"]["schedule"] == "" + # This is set by default to * * * * * ? in the index fixture + assert result["hnsw_params"]["healer_params"]["schedule"] == "* * * * * ?" assert result["hnsw_params"]["healer_params"]["parallelism"] == 0 assert result["hnsw_params"]["merge_params"]["index_parallelism"] == 0 @@ -105,32 +84,20 @@ async def test_index_get_no_defaults(session_admin_client, empty_test_case, rand assert result["storage"].set_name == "" assert result["storage"]["set_name"] == "" - drop_specified_index(session_admin_client, "test", random_name) - @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se for i in range(10): try: result = session_admin_client.index_get( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: diff --git a/tests/standard/sync/test_admin_client_index_get_status.py b/tests/standard/sync/test_admin_client_index_get_status.py index 79a99dda..0f127829 100644 --- a/tests/standard/sync/test_admin_client_index_get_status.py +++ b/tests/standard/sync/test_admin_client_index_get_status.py @@ -1,7 +1,7 @@ import pytest import grpc -from ...utils import random_name +from ...utils import DEFAULT_NAMESPACE from .sync_utils import drop_specified_index from hypothesis import given, settings, Verbosity @@ -13,28 +13,18 @@ @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_get_status(session_admin_client, empty_test_case, random_name): - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se - result = session_admin_client.index_get_status(namespace="test", name=random_name) +def test_index_get_status(session_admin_client, empty_test_case, index): + result = session_admin_client.index_get_status(namespace=DEFAULT_NAMESPACE, name=index) assert result.unmerged_record_count == 0 - drop_specified_index(session_admin_client, "test", random_name) + drop_specified_index(session_admin_client, DEFAULT_NAMESPACE, index) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_get_status_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, index, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -42,7 +32,7 @@ def test_index_get_status_timeout( for i in range(10): try: result = session_admin_client.index_get_status( - namespace="test", name=random_name, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, name=index, timeout=0.0001 ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_admin_client_index_list.py b/tests/standard/sync/test_admin_client_index_list.py index b2a2c9e7..4a119573 100644 --- a/tests/standard/sync/test_admin_client_index_list.py +++ b/tests/standard/sync/test_admin_client_index_list.py @@ -1,24 +1,15 @@ from aerospike_vector_search import AVSServerError +from .sync_utils import drop_specified_index import pytest import grpc - -from ...utils import random_name - -from .sync_utils import drop_specified_index from hypothesis import given, settings, Verbosity @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) -def test_index_list(session_admin_client, empty_test_case, random_name): - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) +def test_index_list(session_admin_client, empty_test_case, index): result = session_admin_client.index_list(apply_defaults=True) assert len(result) > 0 for index in result: @@ -35,28 +26,17 @@ def test_index_list(session_admin_client, empty_test_case, random_name): assert isinstance(index["hnsw_params"]["batching_params"]["reindex_interval"], int) assert isinstance(index["storage"]["namespace"], str) assert isinstance(index["storage"]["set_name"], str) - drop_specified_index(session_admin_client, "test", random_name) @pytest.mark.parametrize("empty_test_case", [None]) #@given(random_name=index_strategy()) #@settings(max_examples=1, deadline=1000) def test_index_list_timeout( - session_admin_client, empty_test_case, random_name, with_latency + session_admin_client, empty_test_case, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") - try: - session_admin_client.index_create( - namespace="test", - name=random_name, - vector_field="science", - dimensions=1024, - ) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.ALREADY_EXISTS: - raise se for i in range(10): diff --git a/tests/standard/sync/test_admin_client_index_update.py b/tests/standard/sync/test_admin_client_index_update.py index 77ee984d..e5be9822 100644 --- a/tests/standard/sync/test_admin_client_index_update.py +++ b/tests/standard/sync/test_admin_client_index_update.py @@ -1,16 +1,14 @@ import time -import pytest -from aerospike_vector_search import types, AVSServerError -import grpc -from .sync_utils import drop_specified_index +from aerospike_vector_search import types +from utils import DEFAULT_NAMESPACE +import pytest class index_update_test_case: def __init__( self, *, - namespace, vector_field, dimensions, initial_labels, @@ -18,7 +16,6 @@ def __init__( hnsw_index_update, timeout ): - self.namespace = namespace self.vector_field = vector_field self.dimensions = dimensions self.initial_labels = initial_labels @@ -31,7 +28,6 @@ def __init__( "test_case", [ index_update_test_case( - namespace="test", vector_field="update_2", dimensions=256, initial_labels={"status": "active"}, @@ -48,30 +44,11 @@ def __init__( ), ], ) -def test_index_update(session_admin_client, test_case): - trimmed_random = "saUEN1-" - - # Drop any pre-existing index with the same name - try: - session_admin_client.index_drop(namespace="test", name=trimmed_random) - except AVSServerError as se: - if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: - pass - - # Create the index - session_admin_client.index_create( - namespace=test_case.namespace, - name=trimmed_random, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - index_labels=test_case.initial_labels, - timeout=test_case.timeout, - ) - +def test_index_update(session_admin_client, test_case, index): # Update the index with parameters based on the test case session_admin_client.index_update( - namespace=test_case.namespace, - name=trimmed_random, + namespace=DEFAULT_NAMESPACE, + name=index, index_labels=test_case.update_labels, hnsw_update_params=test_case.hnsw_index_update, timeout=100_000, @@ -80,9 +57,11 @@ def test_index_update(session_admin_client, test_case): time.sleep(10) # Verify the update - result = session_admin_client.index_get(namespace=test_case.namespace, name=trimmed_random, apply_defaults=True) + result = session_admin_client.index_get(namespace=DEFAULT_NAMESPACE, name=index, apply_defaults=True) assert result, "Expected result to be non-empty but got an empty dictionary." + assert result["id"]["namespace"] == DEFAULT_NAMESPACE + # Assertions if test_case.hnsw_index_update.batching_params: assert result["hnsw_params"]["batching_params"]["max_index_records"] == test_case.hnsw_index_update.batching_params.max_index_records @@ -104,6 +83,3 @@ def test_index_update(session_admin_client, test_case): assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == test_case.hnsw_index_update.healer_params.max_scan_rate_per_node assert result["hnsw_params"]["enable_vector_integrity_check"] == test_case.hnsw_index_update.enable_vector_integrity_check - - # Clean up by dropping the index after the test - drop_specified_index(session_admin_client, test_case.namespace, trimmed_random) diff --git a/tests/standard/sync/test_service_config.py b/tests/standard/sync/test_service_config.py index 3083a74e..b22f1088 100644 --- a/tests/standard/sync/test_service_config.py +++ b/tests/standard/sync/test_service_config.py @@ -1,495 +1,495 @@ -import pytest -import time - -import os -import json - -from aerospike_vector_search import AVSServerError, types -from aerospike_vector_search import AdminClient - - -class service_config_parse_test_case: - def __init__(self, *, service_config_path): - self.service_config_path = service_config_path - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_parse_test_case( - service_config_path="service_configs/master.json" - ), - ], -) -def test_admin_client_service_config_parse( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - service_config_path=test_case.service_config_path, - ssl_target_name_override=ssl_target_name_override, - ) as client: - pass - - -class service_config_test_case: - def __init__( - self, *, service_config_path, namespace, name, vector_field, dimensions - ): - - script_dir = os.path.dirname(os.path.abspath(__file__)) - - self.service_config_path = os.path.abspath( - os.path.join(script_dir, "..", "..", service_config_path) - ) - - with open(self.service_config_path, "rb") as f: - self.service_config = json.load(f) - - self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ - "maxAttempts" - ] - self.initial_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] - ) - self.max_backoff = int( - self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] - ) - self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ - "backoffMultiplier" - ] - self.retryable_status_codes = self.service_config["methodConfig"][0][ - "retryPolicy" - ]["retryableStatusCodes"] - self.namespace = namespace - self.name = name - self.vector_field = vector_field - self.dimensions = dimensions - - -def calculate_expected_time( - max_attempts, - initial_backoff, - backoff_multiplier, - max_backoff, - retryable_status_codes, -): - - current_backkoff = initial_backoff - - expected_time = 0 - for attempt in range(max_attempts - 1): - expected_time += current_backkoff - current_backkoff *= backoff_multiplier - current_backkoff = min(current_backkoff, max_backoff) - - return expected_time - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retries.json", - namespace="test", - name="service_config_index_1", - vector_field="example_1", - dimensions=1024, - ) - ], -) -def test_admin_client_service_config_retries( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - service_config_path=test_case.service_config_path, - ssl_target_name_override=ssl_target_name_override, - - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/initial_backoff.json", - namespace="test", - name="service_config_index_2", - vector_field="example_1", - dimensions=1024, - ) - ], -) -def test_admin_client_service_config_initial_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/max_backoff.json", - namespace="test", - name="service_config_index_3", - vector_field="example_1", - dimensions=1024, - ), - service_config_test_case( - service_config_path="service_configs/max_backoff_lower_than_initial.json", - namespace="test", - name="service_config_index_4", - vector_field="example_1", - dimensions=1024, - ), - ], -) -def test_admin_client_service_config_max_backoff( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/backoff_multiplier.json", - namespace="test", - name="service_config_index_5", - vector_field="example_1", - dimensions=1024, - ) - ], -) -def test_admin_client_service_config_backoff_multiplier( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - try: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - except: - pass - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_create( - namespace=test_case.namespace, - name=test_case.name, - vector_field=test_case.vector_field, - dimensions=test_case.dimensions, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 - - -@pytest.mark.parametrize( - "test_case", - [ - service_config_test_case( - service_config_path="service_configs/retryable_status_codes.json", - namespace="test", - name="service_config_index_6", - vector_field=None, - dimensions=None, - ) - ], -) -def test_admin_client_service_config_retryable_status_codes( - host, - port, - username, - password, - root_certificate, - certificate_chain, - private_key, - ssl_target_name_override, - test_case, -): - - if root_certificate: - with open(root_certificate, "rb") as f: - root_certificate = f.read() - - if certificate_chain: - with open(certificate_chain, "rb") as f: - certificate_chain = f.read() - if private_key: - with open(private_key, "rb") as f: - private_key = f.read() - - with AdminClient( - seeds=types.HostPort(host=host, port=port), - username=username, - password=password, - root_certificate=root_certificate, - certificate_chain=certificate_chain, - private_key=private_key, - ssl_target_name_override=ssl_target_name_override, - service_config_path=test_case.service_config_path, - ) as client: - - expected_time = calculate_expected_time( - test_case.max_attempts, - test_case.initial_backoff, - test_case.backoff_multiplier, - test_case.max_backoff, - test_case.retryable_status_codes, - ) - start_time = time.time() - - with pytest.raises(AVSServerError) as e_info: - client.index_get_status( - namespace=test_case.namespace, - name=test_case.name, - ) - - end_time = time.time() - elapsed_time = end_time - start_time - assert abs(elapsed_time - expected_time) < 1.5 +# import pytest +# import time + +# import os +# import json + +# from aerospike_vector_search import AVSServerError, types +# from aerospike_vector_search import AdminClient + + +# class service_config_parse_test_case: +# def __init__(self, *, service_config_path): +# self.service_config_path = service_config_path + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_parse_test_case( +# service_config_path="service_configs/master.json" +# ), +# ], +# ) +# def test_admin_client_service_config_parse( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# service_config_path=test_case.service_config_path, +# ssl_target_name_override=ssl_target_name_override, +# ) as client: +# pass + + +# class service_config_test_case: +# def __init__( +# self, *, service_config_path, namespace, name, vector_field, dimensions +# ): + +# script_dir = os.path.dirname(os.path.abspath(__file__)) + +# self.service_config_path = os.path.abspath( +# os.path.join(script_dir, "..", "..", service_config_path) +# ) + +# with open(self.service_config_path, "rb") as f: +# self.service_config = json.load(f) + +# self.max_attempts = self.service_config["methodConfig"][0]["retryPolicy"][ +# "maxAttempts" +# ] +# self.initial_backoff = int( +# self.service_config["methodConfig"][0]["retryPolicy"]["initialBackoff"][:-1] +# ) +# self.max_backoff = int( +# self.service_config["methodConfig"][0]["retryPolicy"]["maxBackoff"][:-1] +# ) +# self.backoff_multiplier = self.service_config["methodConfig"][0]["retryPolicy"][ +# "backoffMultiplier" +# ] +# self.retryable_status_codes = self.service_config["methodConfig"][0][ +# "retryPolicy" +# ]["retryableStatusCodes"] +# self.namespace = namespace +# self.name = name +# self.vector_field = vector_field +# self.dimensions = dimensions + + +# def calculate_expected_time( +# max_attempts, +# initial_backoff, +# backoff_multiplier, +# max_backoff, +# retryable_status_codes, +# ): + +# current_backkoff = initial_backoff + +# expected_time = 0 +# for attempt in range(max_attempts - 1): +# expected_time += current_backkoff +# current_backkoff *= backoff_multiplier +# current_backkoff = min(current_backkoff, max_backoff) + +# return expected_time + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retries.json", +# namespace="test", +# name="service_config_index_1", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# def test_admin_client_service_config_retries( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# service_config_path=test_case.service_config_path, +# ssl_target_name_override=ssl_target_name_override, + +# ) as client: + +# try: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/initial_backoff.json", +# namespace="test", +# name="service_config_index_2", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# def test_admin_client_service_config_initial_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time + +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/max_backoff.json", +# namespace="test", +# name="service_config_index_3", +# vector_field="example_1", +# dimensions=1024, +# ), +# service_config_test_case( +# service_config_path="service_configs/max_backoff_lower_than_initial.json", +# namespace="test", +# name="service_config_index_4", +# vector_field="example_1", +# dimensions=1024, +# ), +# ], +# ) +# def test_admin_client_service_config_max_backoff( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/backoff_multiplier.json", +# namespace="test", +# name="service_config_index_5", +# vector_field="example_1", +# dimensions=1024, +# ) +# ], +# ) +# def test_admin_client_service_config_backoff_multiplier( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# try: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) +# except: +# pass +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_create( +# namespace=test_case.namespace, +# name=test_case.name, +# vector_field=test_case.vector_field, +# dimensions=test_case.dimensions, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 + + +# @pytest.mark.parametrize( +# "test_case", +# [ +# service_config_test_case( +# service_config_path="service_configs/retryable_status_codes.json", +# namespace="test", +# name="service_config_index_6", +# vector_field=None, +# dimensions=None, +# ) +# ], +# ) +# def test_admin_client_service_config_retryable_status_codes( +# host, +# port, +# username, +# password, +# root_certificate, +# certificate_chain, +# private_key, +# ssl_target_name_override, +# test_case, +# ): + +# if root_certificate: +# with open(root_certificate, "rb") as f: +# root_certificate = f.read() + +# if certificate_chain: +# with open(certificate_chain, "rb") as f: +# certificate_chain = f.read() +# if private_key: +# with open(private_key, "rb") as f: +# private_key = f.read() + +# with AdminClient( +# seeds=types.HostPort(host=host, port=port), +# username=username, +# password=password, +# root_certificate=root_certificate, +# certificate_chain=certificate_chain, +# private_key=private_key, +# ssl_target_name_override=ssl_target_name_override, +# service_config_path=test_case.service_config_path, +# ) as client: + +# expected_time = calculate_expected_time( +# test_case.max_attempts, +# test_case.initial_backoff, +# test_case.backoff_multiplier, +# test_case.max_backoff, +# test_case.retryable_status_codes, +# ) +# start_time = time.time() + +# with pytest.raises(AVSServerError) as e_info: +# client.index_get_status( +# namespace=test_case.namespace, +# name=test_case.name, +# ) + +# end_time = time.time() +# elapsed_time = end_time - start_time +# assert abs(elapsed_time - expected_time) < 1.5 diff --git a/tests/standard/sync/test_vector_client_delete.py b/tests/standard/sync/test_vector_client_delete.py index 60938a09..50f203d2 100644 --- a/tests/standard/sync/test_vector_client_delete.py +++ b/tests/standard/sync/test_vector_client_delete.py @@ -3,7 +3,7 @@ import grpc from aerospike_vector_search import AVSServerError -from ...utils import random_key +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -13,13 +13,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = record_data self.timeout = timeout @@ -29,33 +27,25 @@ def __init__( "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -def test_vector_delete(session_vector_client, test_case, random_key): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +def test_vector_delete(session_vector_client, test_case, record): session_vector_client.delete( namespace=test_case.namespace, - key=random_key, + key=record, ) with pytest.raises(AVSServerError) as e_info: result = session_vector_client.get( - namespace=test_case.namespace, key=random_key + namespace=test_case.namespace, key=record ) @@ -65,9 +55,8 @@ def test_vector_delete(session_vector_client, test_case, random_key): "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), ], @@ -85,15 +74,14 @@ def test_vector_delete_without_record(session_vector_client, test_case, random_k "test_case", [ delete_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=0.0001, ), ], ) def test_vector_delete_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -101,7 +89,7 @@ def test_vector_delete_timeout( for i in range(10): try: session_vector_client.delete( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_vector_client_exists.py b/tests/standard/sync/test_vector_client_exists.py index fbacf1ca..5be07316 100644 --- a/tests/standard/sync/test_vector_client_exists.py +++ b/tests/standard/sync/test_vector_client_exists.py @@ -1,7 +1,7 @@ import pytest import grpc -from ...utils import random_key +from utils import DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity from aerospike_vector_search import types, AVSServerError @@ -12,13 +12,11 @@ def __init__( self, *, namespace, - record_data, set_name, timeout, ): self.namespace = namespace self.set_name = set_name - self.record_data = record_data self.timeout = timeout @@ -28,38 +26,24 @@ def __init__( "test_case", [ exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"skills": [i for i in range(1024)]}, timeout=None, ), exists_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, timeout=None, ), ], ) -def test_vector_exists(session_vector_client, test_case, random_key): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - timeout=None, - ) +def test_vector_exists(session_vector_client, test_case, record): result = session_vector_client.exists( namespace=test_case.namespace, - key=random_key, + key=record, ) assert result is True - session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @@ -67,12 +51,12 @@ def test_vector_exists(session_vector_client, test_case, random_key): "test_case", [ exists_test_case( - namespace="test", set_name=None, record_data=None, timeout=0.0001 + namespace=DEFAULT_NAMESPACE, set_name=None, timeout=0.0001 ), ], ) def test_vector_exists_timeout( - session_vector_client, test_case, random_key, with_latency + session_vector_client, test_case, record, with_latency ): if not with_latency: pytest.skip("Server latency too low to test timeout") @@ -80,7 +64,7 @@ def test_vector_exists_timeout( for i in range(10): try: result = session_vector_client.exists( - namespace=test_case.namespace, key=random_key, timeout=test_case.timeout + namespace=test_case.namespace, key=record, timeout=test_case.timeout ) except AVSServerError as se: if se.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: diff --git a/tests/standard/sync/test_vector_client_get.py b/tests/standard/sync/test_vector_client_get.py index ab26c1f1..b7fc1c48 100644 --- a/tests/standard/sync/test_vector_client_get.py +++ b/tests/standard/sync/test_vector_client_get.py @@ -1,11 +1,10 @@ +from aerospike_vector_search import types, AVSServerError +from utils import DEFAULT_NAMESPACE, random_key + import pytest import grpc -from ...utils import random_key - from hypothesis import given, settings, Verbosity -from aerospike_vector_search import types, AVSServerError - class get_test_case: def __init__( @@ -15,7 +14,6 @@ def __init__( include_fields, exclude_fields, set_name, - record_data, expected_fields, timeout, ): @@ -23,7 +21,6 @@ def __init__( self.include_fields = include_fields self.exclude_fields = exclude_fields self.set_name = set_name - self.record_data = record_data self.expected_fields = expected_fields self.timeout = timeout @@ -31,99 +28,103 @@ def __init__( #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @pytest.mark.parametrize( - "test_case", + "record,test_case", [ - get_test_case( - namespace="test", - include_fields=["skills"], - exclude_fields=None, - set_name=None, - record_data={"skills": [i for i in range(1024)]}, - expected_fields={"skills": [i for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"skills": 1024}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["skills"], + exclude_fields=None, + set_name=None, + expected_fields={"skills": 1024}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - exclude_fields=None, - set_name=None, - record_data={"english": [float(i) for i in range(1024)]}, - expected_fields={"english": [float(i) for i in range(1024)]}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": [float(i) for i in range(1024)]}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields=None, + set_name=None, + expected_fields={"english": [float(i) for i in range(1024)]}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["english"], - exclude_fields = None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["english"], + exclude_fields=None, + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=["spanish"], + set_name=None, + expected_fields={"english": 1}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=["spanish"], - exclude_fields=["spanish"], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=["spanish"], + exclude_fields=["spanish"], + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=[], - exclude_fields=None, - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=[], + exclude_fields=None, + set_name=None, + expected_fields={}, + timeout=None, + ), ), - get_test_case( - namespace="test", - include_fields=None, - exclude_fields=[], - set_name=None, - record_data={"english": 1, "spanish": 2}, - expected_fields={"english": 1, "spanish": 2}, - timeout=None, + ( + {"record_generator": lambda count, vec_bin, vec_dim: (yield ("key1", {"english": 1, "spanish": 2}))}, + get_test_case( + namespace=DEFAULT_NAMESPACE, + include_fields=None, + exclude_fields=[], + set_name=None, + expected_fields={"english": 1, "spanish": 2}, + timeout=None, + ), ), ], + indirect=["record"], ) -def test_vector_get(session_vector_client, test_case, random_key): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) +def test_vector_get(session_vector_client, test_case, record): result = session_vector_client.get( namespace=test_case.namespace, - key=random_key, + key=record, include_fields=test_case.include_fields, exclude_fields=test_case.exclude_fields, ) assert result.key.namespace == test_case.namespace - if test_case.set_name == None: + if test_case.set_name is None: test_case.set_name = "" assert result.key.set == test_case.set_name - assert result.key.key == random_key + assert result.key.key == record assert result.fields == test_case.expected_fields - session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) - #@given(random_key=key_strategy()) #@settings(max_examples=1, deadline=1000) @@ -131,11 +132,10 @@ def test_vector_get(session_vector_client, test_case, random_key): "test_case", [ get_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, include_fields=["skills"], exclude_fields=None, set_name=None, - record_data=None, expected_fields=None, timeout=0.0001, ), diff --git a/tests/standard/sync/test_vector_client_insert.py b/tests/standard/sync/test_vector_client_insert.py index 5c1a39e0..b763b9fa 100644 --- a/tests/standard/sync/test_vector_client_insert.py +++ b/tests/standard/sync/test_vector_client_insert.py @@ -2,7 +2,7 @@ from aerospike_vector_search import AVSServerError import grpc -from ...utils import random_key +from utils import random_key, DEFAULT_NAMESPACE from hypothesis import given, settings, Verbosity @@ -26,21 +26,21 @@ def __init__( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"homeSkills": [float(i) for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, timeout=None, ), insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, @@ -75,7 +75,7 @@ def test_vector_insert_without_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, @@ -84,25 +84,15 @@ def test_vector_insert_without_existing_record( ], ) def test_vector_insert_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - ) with pytest.raises(AVSServerError) as e_info: session_vector_client.insert( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, ) - session_vector_client.delete( - namespace=test_case.namespace, - key=random_key, - ) #@given(random_key=key_strategy()) @@ -111,7 +101,7 @@ def test_vector_insert_with_existing_record( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"english": [bool(i) for i in range(1024)]}, set_name=None, ignore_mem_queue_full=True, @@ -147,7 +137,7 @@ def test_vector_insert_without_existing_record_ignore_mem_queue_full( "test_case", [ insert_test_case( - namespace="test", + namespace=DEFAULT_NAMESPACE, record_data={"math": [i for i in range(1024)]}, set_name=None, ignore_mem_queue_full=None, @@ -170,6 +160,10 @@ def test_vector_insert_timeout( set_name=test_case.set_name, timeout=test_case.timeout, ) + session_vector_client.delete( + namespace=test_case.namespace, + key=random_key, + ) except AVSServerError as e: if e.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: assert e.rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED diff --git a/tests/standard/sync/test_vector_client_is_indexed.py b/tests/standard/sync/test_vector_client_is_indexed.py index 4f761957..0c4baa97 100644 --- a/tests/standard/sync/test_vector_client_is_indexed.py +++ b/tests/standard/sync/test_vector_client_is_indexed.py @@ -1,76 +1,46 @@ import pytest -import random + from aerospike_vector_search import AVSServerError -from aerospike_vector_search import Client, AdminClient +from utils import random_name, DEFAULT_NAMESPACE +from .sync_utils import wait_for_index import grpc -# Define module-level constants for common arguments -NAMESPACE = "test" -SET_NAME = "isidxset" -INDEX_NAME = "isidx" -DIMENSIONS = 3 -VECTOR_FIELD = "vector" -@pytest.fixture(scope="module", autouse=True) -async def setup_teardown(session_vector_client: Client, session_admin_client: AdminClient): - # Setup: Create index and upsert records - session_admin_client.index_create( - namespace=NAMESPACE, - sets=SET_NAME, - name=INDEX_NAME, - dimensions=DIMENSIONS, - vector_field=VECTOR_FIELD, - ) - for i in range(10): - session_vector_client.upsert( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i), - record_data={VECTOR_FIELD: [float(i)] * DIMENSIONS}, - ) - session_vector_client.wait_for_index_completion( - namespace=NAMESPACE, - name=INDEX_NAME, - ) - yield - # Teardown: remove records - for i in range(10): - session_vector_client.delete( - namespace=NAMESPACE, - set_name=SET_NAME, - key=str(i) - ) - - # Teardown: Drop index - session_admin_client.index_drop( - namespace=NAMESPACE, - name=INDEX_NAME - ) - -async def test_vector_is_indexed( +def test_vector_is_indexed( + session_admin_client, session_vector_client, + index, + record, ): + # wait for the record to be indexed + wait_for_index( + admin_client=session_admin_client, + namespace=DEFAULT_NAMESPACE, + index=index + ) + result = session_vector_client.is_indexed( - namespace=NAMESPACE, - key="0", - index_name=INDEX_NAME, - set_name=SET_NAME, + namespace=DEFAULT_NAMESPACE, + key=record, + index_name=index, ) assert result is True -async def test_vector_is_indexed_timeout( - session_vector_client, with_latency +def test_vector_is_indexed_timeout( + session_vector_client, + with_latency, + random_name, ): if not with_latency: pytest.skip("Server latency too low to test timeout") for _ in range(10): try: session_vector_client.is_indexed( - namespace=NAMESPACE, + namespace=DEFAULT_NAMESPACE, key="0", - index_name=INDEX_NAME, + index_name=random_name, timeout=0.0001, ) except AVSServerError as se: diff --git a/tests/standard/sync/test_vector_client_search_by_key.py b/tests/standard/sync/test_vector_client_search_by_key.py index f46c401e..044a0ba0 100644 --- a/tests/standard/sync/test_vector_client_search_by_key.py +++ b/tests/standard/sync/test_vector_client_search_by_key.py @@ -301,6 +301,19 @@ def test_vector_search_by_key( name=test_case.index_name, vector_field=test_case.vector_field, dimensions=test_case.index_dimensions, + index_params=types.HnswParams( + batching_params=types.HnswBatchingParams( + # 10_000 is the minimum value, in order for the tests to run as + # fast as possible we set it to the minimum value so records are indexed + # quickly + index_interval=10_000, + ), + healer_params=types.HnswHealerParams( + # run the healer every second + # for fast indexing + schedule="* * * * * ?" + ) + ) ) for key, rec in test_case.record_data.items(): diff --git a/tests/standard/sync/test_vector_client_update.py b/tests/standard/sync/test_vector_client_update.py index ef111b49..c11fe643 100644 --- a/tests/standard/sync/test_vector_client_update.py +++ b/tests/standard/sync/test_vector_client_update.py @@ -41,19 +41,11 @@ def __init__(self, *, namespace, record_data, set_name, timeout): ], ) def test_vector_update_with_existing_record( - session_vector_client, test_case, random_key + session_vector_client, test_case, record ): - session_vector_client.upsert( - namespace=test_case.namespace, - key=random_key, - record_data=test_case.record_data, - set_name=test_case.set_name, - timeout=None, - ) - session_vector_client.update( namespace=test_case.namespace, - key=random_key, + key=record, record_data=test_case.record_data, set_name=test_case.set_name, timeout=None, diff --git a/tests/standard/sync/test_vector_search.py b/tests/standard/sync/test_vector_search.py index 5648c6fb..415e038c 100644 --- a/tests/standard/sync/test_vector_search.py +++ b/tests/standard/sync/test_vector_search.py @@ -39,7 +39,7 @@ def __init__( vector_search_test_case( index_name="basic_search", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], namespace="test", @@ -49,7 +49,7 @@ def __init__( record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ @@ -61,7 +61,7 @@ def __init__( ), fields={ "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, distance=3.0, ), @@ -70,7 +70,7 @@ def __init__( vector_search_test_case( index_name="field_filter", index_dimensions=3, - vector_field="vector", + vector_field="vecs", limit=3, query=[0.0, 0.0, 0.0], namespace="test", @@ -80,7 +80,7 @@ def __init__( record_data={ "rec1": { "bin1": 1, - "vector": [1.0, 1.0, 1.0], + "vecs": [1.0, 1.0, 1.0], }, }, expected_results=[ diff --git a/tests/utils.py b/tests/utils.py index 6d3d222c..ddcb1306 100755 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,13 @@ import pytest + +# default test values +DEFAULT_NAMESPACE = "test" +DEFAULT_INDEX_DIMENSION = 128 +DEFAULT_VECTOR_FIELD = "vector" + + def random_int(): return str(random.randint(0, 50_000))