-
Notifications
You must be signed in to change notification settings - Fork 1
/
mixed_import.py
158 lines (124 loc) · 4.23 KB
/
mixed_import.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# import face_recognition
import os
import time
from milvus import *
import psycopg2
import numpy as np
import random
from faker import Faker
# Shared Faker generator; record_txt uses it to produce random past timestamps.
fake = Faker()
# Milvus collection that receives the vectors (created by create_milvus_collection).
# NOTE(review): 'mixe_query' may be a typo for 'mixed_query'; it is a runtime
# name, so changing it would point the script at a different collection/table.
MILVUS_collection = 'mixe_query'
# PostgreSQL table that receives the per-id metadata rows.
PG_TABLE_NAME = 'mixe_query'
# Base-vector file in .bvecs layout: each record is a 4-byte header followed by
# d uint8 components (d is read as an int32 from the first header), sliced in
# batches by load_bvecs_data.
FILE_PATH = 'bigann_base.bvecs'
# Total number of vectors to load; main() processes VEC_NUM // BASE_LEN batches.
VEC_NUM = 100000000
# Vectors per batch (one Milvus insert and one PostgreSQL COPY per batch).
BASE_LEN = 100000
# Dimension declared when creating the Milvus collection.
# NOTE(review): load_bvecs_data takes the dimension from the file header;
# presumably this must match it — confirm for the data file in use.
VEC_DIM = 128
# Milvus server endpoint used by main().
SERVER_ADDR = "127.0.0.1"
SERVER_PORT = 19530
# PostgreSQL connection settings used by connect_postgres_server.
# NOTE(review): credentials are hard-coded; consider reading them from the
# environment or a config file.
PG_HOST = "192.168.1.10"
PG_PORT = 5432
PG_USER = "postgres"
PG_PASSWORD = "postgres"
PG_DATABASE = "postgres"
# NOTE(review): leftover commented-out client construction; main() builds the
# Milvus client itself.
# milvus = Milvus()
def load_bvecs_data(fname, base_len, idx):
    """Read batch ``idx`` of ``base_len`` vectors from a .bvecs file.

    Each record in the file is a 4-byte header followed by ``d`` uint8
    components, where ``d`` is read as an int32 from the file's first four
    bytes. The file is memory-mapped read-only, rows
    ``[base_len * idx, base_len * (idx + 1))`` are selected, and every
    component ``c`` is mapped to ``(c + 0.5) / 255``.

    Returns:
        A list of lists of Python floats. It holds fewer than ``base_len``
        rows if the requested slice runs past the end of the file.
    """
    raw = np.memmap(fname, dtype='uint8', mode='r')
    dim = raw[:4].view('int32')[0]
    start = idx * base_len
    # One row per record: header bytes + dim components.
    rows = raw.reshape(-1, dim + 4)[start:start + base_len]
    components = rows[:, 4:]  # drop the 4-byte per-record header
    return ((components + 0.5) / 255).tolist()
def create_milvus_collection(milvus):
    """Create the Milvus collection ``MILVUS_collection`` if it is missing.

    Element ``[1]`` of ``milvus.has_collection(...)`` is used as the
    "already exists" flag. When that flag is falsy, a collection of dimension
    ``VEC_DIM`` with an L2 metric and index_file_size 1024 is created.
    """
    exists = milvus.has_collection(MILVUS_collection)[1]
    if exists:
        return
    collection_spec = dict(
        collection_name=MILVUS_collection,
        dimension=VEC_DIM,
        index_file_size=1024,
        metric_type=MetricType.L2,
    )
    milvus.create_collection(collection_spec)
def build_collection(milvus):
    """Request an IVF_SQ8H index (nlist=16384) on ``MILVUS_collection``.

    The value returned by ``milvus.create_index`` (presumably a status
    object) is printed as-is.
    """
    ivf_params = dict(nlist=16384)
    outcome = milvus.create_index(MILVUS_collection, IndexType.IVF_SQ8H, ivf_params)
    print(outcome)
def connect_postgres_server():
    """Open a connection to the PostgreSQL server described by the PG_* constants.

    Returns:
        A psycopg2 connection on success, or ``None`` if the connection
        attempt failed. The failure is printed, not raised.
    """
    try:
        return psycopg2.connect(host=PG_HOST, port=PG_PORT, user=PG_USER,
                                password=PG_PASSWORD, database=PG_DATABASE)
    except psycopg2.Error as err:
        # Narrowed from a bare ``except:``: Ctrl-C and programming errors are
        # no longer swallowed. The driver's message is printed for diagnosis.
        print("unable to connect to the database", err)
        return None
def create_pg_table(conn, cur):
    """Create ``PG_TABLE_NAME`` with columns (ids, sex, get_time, is_glasses).

    A failure (e.g. the table already exists on a rerun) is printed rather
    than raised. The transaction is rolled back on failure because psycopg2
    leaves the connection in an aborted-transaction state after an error.
    Without the rollback, every later statement on ``conn`` fails as well,
    including the per-batch COPY.
    """
    sql = ("CREATE TABLE " + PG_TABLE_NAME
           + " (ids bigint, sex char(10), get_time timestamp, is_glasses boolean);")
    try:
        cur.execute(sql)
        conn.commit()
    except psycopg2.Error as err:
        conn.rollback()
        print("can't create postgres table", err)
    else:
        print("create postgres table!")
def insert_data_to_pg(ids, vector, sex, get_time, is_glasses, conn, cur):
    """Insert one row (ids, vector, sex, get_time, is_glasses) into ``PG_TABLE_NAME``.

    Values are sent as query parameters instead of being spliced into the
    SQL text, so quotes or other special characters in the values cannot
    break (or inject into) the statement. ``vector`` is sent as a
    PostgreSQL array. The string-typed values are passed as text, mirroring
    the quoted literals the original statement built.

    Failures are printed, not raised. The transaction is rolled back so the
    connection stays usable for later statements.

    NOTE(review): this writes five values, but create_pg_table defines only
    four columns (no vector column), so the insert would be rejected against
    that schema. It is not called from main(); confirm which table layout it
    targets.
    """
    # The table name is a module constant, not user input, so concatenating
    # it is safe. Only the values are parameterized.
    sql = "INSERT INTO " + PG_TABLE_NAME + " VALUES (%s, %s, %s, %s, %s);"
    params = (int(ids), list(vector), str(sex), str(get_time), str(is_glasses))
    try:
        cur.execute(sql, params)
        conn.commit()
    except psycopg2.Error as err:
        conn.rollback()
        print("faild insert", err)
def copy_data_to_pg(conn, cur):
    """Bulk-load ./temp.csv (written by record_txt) into ``PG_TABLE_NAME``.

    The file is streamed from this client with ``COPY ... FROM STDIN``
    through ``cursor.copy_expert``. The previous ``COPY ... FROM '<path>'``
    form is executed by the server and reads the path on the server's own
    filesystem (and needs elevated server privileges). It therefore fails
    whenever PG_HOST is not the machine running this script.

    Failures are printed, not raised. The transaction is rolled back so the
    next batch's COPY is not rejected by an aborted transaction.
    """
    fname = 'temp.csv'
    csv_path = os.path.join(os.getcwd(), fname)
    sql = "COPY " + PG_TABLE_NAME + " FROM STDIN WITH CSV DELIMITER '|';"
    try:
        with open(csv_path, 'r') as csv_file:
            cur.copy_expert(sql, csv_file)
        conn.commit()
    except (psycopg2.Error, OSError) as err:
        conn.rollback()
        print("faild copy!", err)
    else:
        print("insert pg sucessful!")
def build_pg_index(conn, cur):
    """Create index ``index_ids`` on ``PG_TABLE_NAME(ids)``.

    A failure (e.g. the index already exists on a rerun) is printed, not
    raised. The transaction is rolled back so the connection remains usable.
    """
    sql = "CREATE INDEX index_ids on " + PG_TABLE_NAME + "(ids);"
    try:
        cur.execute(sql)
        conn.commit()
    except psycopg2.Error as err:
        conn.rollback()
        print("faild build index", err)
    else:
        print("build index sucessful!")
def record_txt(ids):
    """Overwrite ./temp.csv with one pipe-delimited metadata row per id.

    Row format: ``<id>|<sex>|'<timestamp>'|<is_glasses>``.
    - ``sex`` is a random pick of 'female'/'male'.
    - ``timestamp`` comes from ``fake.past_datetime(start_date="-120d")``.
    - ``is_glasses`` is a random pick of 'True'/'False'.

    NOTE(review): the single quotes around the timestamp are literal
    characters, not CSV quoting, in the COPY that reads this file. Confirm
    PostgreSQL's timestamp parser accepts them for the server version in use.
    """
    with open('temp.csv', 'w+') as out:
        for vec_id in ids:
            # Draw order (sex, time, glasses) is kept per row on purpose.
            gender = random.choice(['female', 'male'])
            stamp = fake.past_datetime(start_date="-120d", tzinfo=None)
            glasses = random.choice(['True', 'False'])
            out.write(f"{vec_id}|{gender}|'{stamp}'|{glasses}\n")
def main():
    """Load ``VEC_NUM`` vectors into Milvus and mirror per-id metadata into PostgreSQL.

    For each of ``VEC_NUM // BASE_LEN`` batches:
    1. Read the batch from ``FILE_PATH``.
    2. Insert it into Milvus with explicit ids.
    3. Write a metadata CSV for the returned ids and COPY it into
       ``PG_TABLE_NAME``.

    The timing of each step is printed. A PostgreSQL index on ``ids`` is
    built at the end. The cursor and connection are always closed.
    """
    milvus = Milvus(host=SERVER_ADDR, port=SERVER_PORT)
    create_milvus_collection(milvus)
    build_collection(milvus)
    conn = connect_postgres_server()
    if conn is None:
        # connect_postgres_server has already printed the failure. Stop here
        # instead of crashing on ``None.cursor()``.
        return
    cur = conn.cursor()
    try:
        create_pg_table(conn, cur)
        for count in range(VEC_NUM // BASE_LEN):
            vectors = load_bvecs_data(FILE_PATH, BASE_LEN, count)
            vectors_ids = list(range(count * BASE_LEN, (count + 1) * BASE_LEN))
            time_start = time.time()
            status, ids = milvus.insert(collection_name=MILVUS_collection,
                                        records=vectors, ids=vectors_ids)
            time_end = time.time()
            print(count, "insert milvue time: ", time_end - time_start)
            if not ids:
                # NOTE(review): assumes an empty/None id list means the insert
                # failed. Report the status instead of copying an empty batch.
                print(count, "milvus insert failed: ", status)
                continue
            time_start = time.time()
            record_txt(ids)
            copy_data_to_pg(conn, cur)
            time_end = time.time()
            print(count, "insert pg time: ", time_end - time_start)
        build_pg_index(conn, cur)
    finally:
        cur.close()
        conn.close()
# Run the bulk load only when executed as a script, not when imported.
if __name__ == '__main__':
    main()