faiss_index.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import faiss
from faiss import cast_integer_to_float_ptr as cast_float
from faiss import cast_integer_to_int_ptr as cast_int
from faiss import cast_integer_to_long_ptr as cast_long

# util is expected to provide `T` (torch), `ensure_gpu` and `ptr`
from util import *
class FAISSIndex(object):

    def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
        super(FAISSIndex, self).__init__()
        self.cell_size = cell_size
        self.nr_cells = nr_cells
        self.probes = probes
        self.K = K
        self.num_lists = num_lists
        self.gpu_id = gpu_id

        # BEWARE: if this resources object gets deallocated, FAISS crashes
        if res:
            self.res = res
            print("common res object used")
        else:
            self.res = faiss.StandardGpuResources()
            self.res.setTempMemoryFraction(0.01)
            if self.gpu_id != -1:
                self.res.initializeForDevice(self.gpu_id)

        # An earlier revision used a trained GpuIndexIVFFlat with `num_lists`
        # inverted lists and `probes` probes per query; the brute-force flat
        # L2 index below needs no training.
        self.index = faiss.GpuIndexFlatL2(self.res, self.cell_size)
    def cuda(self, gpu_id):
        # Only records the target device; tensors are moved in add() / search().
        self.gpu_id = gpu_id

    def train(self, train):
        # Kept for compatibility with IVF-style indexes; the flat index ignores it.
        train = ensure_gpu(train, -1)
        T.cuda.synchronize()
        # train.size(0) is the number of training vectors passed to FAISS.
        self.index.train_c(train.size(0), cast_float(ptr(train)))
        T.cuda.synchronize()

    def reset(self):
        T.cuda.synchronize()
        self.index.reset()
        T.cuda.synchronize()
    def add(self, other, positions=None, last=None):
        other = ensure_gpu(other, self.gpu_id)

        T.cuda.synchronize()
        if positions is not None:
            positions = ensure_gpu(positions, self.gpu_id)
            assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
            # IDs are stored shifted by +1 so that 0 is never a stored label;
            # search() undoes the shift.
            self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
        else:
            # Optionally add only the first `last` vectors.
            other = other[:last, :] if last is not None else other
            self.index.add_c(other.size(0), cast_float(ptr(other)))
        T.cuda.synchronize()
    def search(self, query, k=None):
        query = ensure_gpu(query, self.gpu_id)
        k = k if k is not None else self.K
        (b, n) = query.size()

        # Pre-allocate result buffers; FAISS writes into them via raw pointers.
        distances = T.FloatTensor(b, k)
        labels = T.LongTensor(b, k)
        if self.gpu_id != -1:
            distances = distances.cuda(self.gpu_id)
            labels = labels.cuda(self.gpu_id)

        T.cuda.synchronize()
        self.index.search_c(
            b,
            cast_float(ptr(query)),
            k,
            cast_float(ptr(distances)),
            cast_long(ptr(labels))
        )
        T.cuda.synchronize()
        # Undo the +1 shift applied in add().
        return (distances, (labels - 1))
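

# A minimal usage sketch, not part of the original module. It assumes a CUDA
# device 0 is available and that util's star import provides `T` as torch;
# the shapes below are illustrative. Vectors are added without explicit
# positions, so FAISS assigns sequential ids, and the returned labels follow
# the class convention of shifting stored IDs back by -1 on search.
if __name__ == '__main__':
    index = FAISSIndex(cell_size=20, nr_cells=1024, K=4, gpu_id=0)

    cells = T.randn(1024, 20)   # 1024 memory cells of width cell_size=20
    index.add(cells)

    queries = T.randn(8, 20)    # a batch of 8 query vectors
    distances, labels = index.search(queries)
    print(distances.size(), labels.size())  # both (8, 4)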