forked from milvus-io/bootcamp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
performance_test.py
92 lines (81 loc) · 3.79 KB
/
performance_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import time
import numpy as np
from logs import LOGGER
from config import QUERY_FILE_PATH, PERFORMANCE_RESULTS_PATH, NQ_SCOPE, TOPK_SCOPE, METRIC_TYPE, PERCENTILE_NUM
def get_search_params(search_param, index_type):
    """Build the Milvus search-parameter dict for the given index type.

    Args:
        search_param: tuning value for the index; its meaning depends on the
            index type (search_length, ef, search_k, or nprobe).
        index_type: index type string, e.g. 'FLAT', 'RNSG', 'HNSW', 'ANNOY';
            any other value (IVF family) falls back to 'nprobe'.

    Returns:
        dict suitable for the search-params argument of a Milvus search call.
    """
    # Each index family names its tuning knob differently; FLAT takes none,
    # and every index not listed here (IVF_FLAT/IVF_SQ8/IVF_PQ, ...) uses nprobe.
    param_key_by_index = {'RNSG': 'search_length', 'HNSW': 'ef', 'ANNOY': 'search_k'}
    if index_type == 'FLAT':
        search_params = {"metric_type": METRIC_TYPE}
    else:
        key = param_key_by_index.get(index_type, 'nprobe')
        search_params = {"metric_type": METRIC_TYPE, "params": {key: search_param}}
    # Use the project logger instead of a bare print (consistent with the file).
    LOGGER.debug(f'search params: {search_params}')
    return search_params
def get_nq_vec(query):
    """Load the query vectors and return the first `query` of them as a list.

    Args:
        query: number of query vectors requested (nq).

    Returns:
        list of at most `query` vectors loaded from QUERY_FILE_PATH; if the
        file holds fewer vectors than requested, logs that fact and returns
        everything available.
    """
    data = np.load(QUERY_FILE_PATH)
    # '>=' so that requesting exactly len(data) vectors does not fall into the
    # shortage branch and log a spurious message (original used '>').
    if len(data) >= query:
        return data[:query].tolist()
    LOGGER.info(f'There is only {len(data)} vectors')
    return data.tolist()
def performance(client, collection_name, search_param):
    """Benchmark search latency over NQ_SCOPE x TOPK_SCOPE and write a CSV.

    Args:
        client: Milvus helper exposing get_index_params() and search_vectors().
        collection_name: name of the collection to search.
        search_param: index tuning value forwarded to get_search_params().

    Side effects:
        Writes '<collection>_<search_param>_performance.csv' under
        PERFORMANCE_RESULTS_PATH with columns nq,topk,total_time,avg_time.
    """
    index_info = client.get_index_params(collection_name)
    # A collection without an index reports no index params; search as FLAT.
    index_type = index_info[0]['index_type'] if index_info else 'FLAT'
    search_params = get_search_params(search_param, index_type)
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # os.path.exists + os.mkdir and also creates any missing parents.
    os.makedirs(PERFORMANCE_RESULTS_PATH, exist_ok=True)
    result_filename = f'{collection_name}_{search_param}_performance.csv'
    performance_file = os.path.join(PERFORMANCE_RESULTS_PATH, result_filename)
    with open(performance_file, 'w+', encoding='utf-8') as f:
        f.write('nq,topk,total_time,avg_time\n')
        for nq in NQ_SCOPE:
            query_list = get_nq_vec(nq)
            LOGGER.info(f"begin to search, nq = {len(query_list)}")
            for topk in TOPK_SCOPE:
                time_start = time.time()
                client.search_vectors(collection_name, query_list, topk, search_params)
                time_cost = time.time() - time_start
                # Log progress via the project logger instead of a bare print.
                LOGGER.info(f'nq={nq}, topk={topk}, total_time={time_cost:.4f}s')
                f.write(f'{nq},{topk},{round(time_cost, 4)},{round(time_cost / nq, 4)}\n')
            # Blank line separates the per-nq result groups in the CSV.
            f.write('\n')
    LOGGER.info("search_vec_list done !")
def percentile_test(client, collection_name, search_param, percentile):
    """Measure a latency percentile over PERCENTILE_NUM repeated searches.

    For every (nq, topk) pair the identical search runs PERCENTILE_NUM times
    and the requested percentile of the per-call latencies is recorded.

    Args:
        client: Milvus helper exposing get_index_params() and search_vectors().
        collection_name: name of the collection to search.
        search_param: index tuning value forwarded to get_search_params().
        percentile: percentile to report (e.g. 99 for p99).

    Side effects:
        Writes '<collection>_<search_param>_percentile.csv' under
        PERFORMANCE_RESULTS_PATH with columns nq,topk,total_time.
    """
    index_info = client.get_index_params(collection_name)
    # A collection without an index reports no index params; search as FLAT.
    index_type = index_info[0]['index_type'] if index_info else 'FLAT'
    search_params = get_search_params(search_param, index_type)
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # os.path.exists + os.mkdir and also creates any missing parents.
    os.makedirs(PERFORMANCE_RESULTS_PATH, exist_ok=True)
    result_filename = f'{collection_name}_{search_param}_percentile.csv'
    performance_file = os.path.join(PERFORMANCE_RESULTS_PATH, result_filename)
    with open(performance_file, 'w+', encoding='utf-8') as f:
        f.write('nq,topk,total_time\n')
        for nq in NQ_SCOPE:
            query_list = get_nq_vec(nq)
            LOGGER.info(f"begin to search, nq = {len(query_list)}")
            for topk in TOPK_SCOPE:
                # Separate name for the sample list — the original reused
                # 'time_cost' for both the list and the percentile scalar.
                latencies = []
                for _ in range(PERCENTILE_NUM):
                    time_start = time.time()
                    client.search_vectors(collection_name, query_list, topk, search_params)
                    latencies.append(time.time() - time_start)
                time_cost = float(np.percentile(np.array(latencies), float(percentile)))
                # Log progress via the project logger instead of a bare print.
                LOGGER.info(f'nq={nq}, topk={topk}, p{percentile}={time_cost:.4f}s')
                f.write(f'{nq},{topk},{round(time_cost, 4)}\n')
            # Blank line separates the per-nq result groups in the CSV.
            f.write('\n')
    LOGGER.info("search_vec_list done !")