-
Notifications
You must be signed in to change notification settings - Fork 4
/
db_upload.py
71 lines (53 loc) · 3.02 KB
/
db_upload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import re

import psycopg
from pgvector.psycopg import register_vector
#from postgis.psycopg import register
class DBUpload:
    """Recreate the ``geoimage`` database and bulk-load image embeddings into it.

    WARNING: constructing an instance is destructive. The database named by
    ``DB_NAME`` is dropped (forcibly disconnecting other sessions) and created
    again, and the ``vector`` and ``postgis`` extensions are installed in it.
    """

    DB_NAME = 'geoimage'

    # Connection settings shared by every connection this class opens.
    # NOTE(review): credentials are hard-coded; override these attributes in a
    # subclass (or move them to configuration) before using outside local dev.
    DB_HOST = 'localhost'
    DB_USER = 'postgres'
    DB_PASSWORD = 'letmein'

    # A plain, unquoted PostgreSQL identifier: a letter or underscore followed
    # by letters, digits, underscores or dollar signs.
    _IDENTIFIER_RE = re.compile(r"[^\W\d][\w$]*")

    def __init__(self, vector_size: int, table_name: str):
        """Drop and recreate the database, then enable the required extensions.

        Args:
            vector_size: Dimension of the ``embedding`` vector column.
            table_name: Name of the table that ``upsert_vectors`` (re)creates.

        Raises:
            ValueError: If ``table_name`` (or ``DB_NAME``) is not a plain SQL
                identifier. Raised before any connection is opened, so nothing
                is dropped for a bad name.
        """
        self.vector_size = vector_size
        self.table_name = self._checked_identifier(table_name)
        db_name = self._checked_identifier(self.DB_NAME)

        # DROP/CREATE DATABASE cannot target the database we are connected to,
        # so issue them from the server's default database.
        with self._connect() as conn:
            conn.execute(f"DROP DATABASE IF EXISTS {db_name} WITH (FORCE)")
            conn.execute(f"CREATE DATABASE {db_name}")

        # Extensions are per-database, so reconnect to the fresh database.
        with self._connect(db_name) as conn:
            conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
            conn.execute('CREATE EXTENSION IF NOT EXISTS postgis')

    @classmethod
    def _checked_identifier(cls, name):
        """Return ``name`` unchanged if it is safe to splice into SQL unquoted."""
        if not isinstance(name, str) or not cls._IDENTIFIER_RE.fullmatch(name):
            raise ValueError(f"not a valid SQL identifier: {name!r}")
        return name

    def _connect(self, dbname=None):
        """Open an autocommit connection; ``dbname=None`` uses the server default."""
        params = {
            "host": self.DB_HOST,
            "user": self.DB_USER,
            "password": self.DB_PASSWORD,
        }
        if dbname is not None:
            params["dbname"] = dbname
        return psycopg.connect(autocommit=True, **params)

    def upsert_vectors(self, ids, vectors, payloads):
        """Recreate the table, COPY all rows into it, and build its indexes.

        Despite the name this is a full reload: the table is dropped and
        created again on every call.

        Args:
            ids: Unused; the ``id`` column is an auto-generated ``bigserial``.
            vectors: Sized sequence of embeddings, one per row.
            payloads: Sequence of dicts aligned with ``vectors``. Each needs
                the keys ``filename``, ``picture``, ``url`` and ``location``,
                where ``location`` is a dict with ``lon`` and ``lat``.
                Entries beyond ``len(vectors)`` are ignored.

        Raises:
            ValueError: If the table name is not a plain identifier, the
                vector size is not a positive integer, or there are fewer
                payloads than vectors. All are checked before the existing
                table is dropped.
        """
        table = self._checked_identifier(self.table_name)
        dim = int(self.vector_size)
        if dim <= 0:
            raise ValueError(f"vector_size must be positive, got {self.vector_size!r}")
        if len(payloads) < len(vectors):
            raise ValueError(
                f"got {len(vectors)} vectors but only {len(payloads)} payloads"
            )

        with self._connect(self.DB_NAME) as conn:
            register_vector(conn)
            #register(conn)
            conn.execute(f'DROP TABLE IF EXISTS {table}')
            # ID is autogenerated and all the other columns besides embedding
            # come from the payload.
            conn.execute(f"""CREATE TABLE {table} (id bigserial PRIMARY KEY,
                filename text,
                picture text,
                url text,
                location geography(POINT,4326),
                embedding vector({dim}))""")

            # Stream the rows with COPY; the location is sent as WKT text with
            # longitude first, then latitude.
            copy_sql = (
                f"COPY {table} (filename, picture, url, location, embedding) "
                "FROM STDIN"
            )
            with conn.cursor() as cursor, cursor.copy(copy_sql) as copy:
                for vector, payload in zip(vectors, payloads):
                    point = payload["location"]
                    location = "POINT( %s %s)" % (point["lon"], point["lat"])
                    copy.write_row([
                        payload["filename"],
                        payload["picture"],
                        payload["url"],
                        location,
                        vector,
                    ])

            # Create the spatial and HNSW indexes after loading the rows.
            print("making spatial index\n")
            conn.execute(
                f"CREATE INDEX {table}_location_idx on {table} USING GIST(location)"
            )
            print("creating HNSW index")
            # Session-level setting; applies to the index build below.
            conn.execute("set maintenance_work_mem to '350MB'")
            conn.execute(f"""CREATE INDEX idx_{table}_hnsw ON {table} USING hnsw
                (embedding vector_cosine_ops) WITH (m = 10, ef_construction = 40)""")