Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IndexIVFFastScan reconstruct_from_offset method #4095

Closed
20 changes: 8 additions & 12 deletions faiss/IndexIVFFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1353,34 +1353,30 @@ void IndexIVFFastScan::reconstruct_from_offset(
int64_t offset,
float* recons) const {
// unpack codes
size_t coarse_size = coarse_code_size();
std::vector<uint8_t> code(coarse_size + code_size, 0);
encode_listno(list_no, code.data());
InvertedLists::ScopedCodes list_codes(invlists, list_no);
std::vector<uint8_t> code(code_size, 0);
BitstringWriter bsw(code.data(), code_size);
BitstringWriter bsw(code.data() + coarse_size, code_size);

for (size_t m = 0; m < M; m++) {
uint8_t c =
pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
bsw.write(c, nbits);
}
sa_decode(1, code.data(), recons);

// add centroid to it
if (by_residual) {
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());
for (int i = 0; i < d; ++i) {
recons[i] += centroid[i];
}
}
sa_decode(1, code.data(), recons);
}

void IndexIVFFastScan::reconstruct_orig_invlists() {
FAISS_THROW_IF_NOT(orig_invlists != nullptr);
FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);

#pragma omp parallel for if (nlist > 100)
for (size_t list_no = 0; list_no < nlist; list_no++) {
InvertedLists::ScopedCodes codes(invlists, list_no);
InvertedLists::ScopedIds ids(invlists, list_no);
size_t list_size = orig_invlists->list_size(list_no);
size_t list_size = invlists->list_size(list_no);
std::vector<uint8_t> code(code_size, 0);

for (size_t offset = 0; offset < list_size; offset++) {
Expand Down
1 change: 1 addition & 0 deletions faiss/IndexIVFPQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
precomputed_table.nbytes());
}

#pragma omp parallel for if (nlist > 100)
for (size_t i = 0; i < nlist; i++) {
size_t nb = orig.invlists->list_size(i);
size_t nb2 = roundup(nb, bbs);
Expand Down
31 changes: 31 additions & 0 deletions tests/test_fast_scan_ivf.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,37 @@ def test_by_residual_odd_dim(self):
self.do_test(by_residual=True, d=30)


class TestReconstruct(unittest.TestCase):
    """Smoke tests for the reconstruction entry points of IndexIVFPQFastScan."""

    def do_test(self, by_residual=False):
        """Build a small IndexIVFPQFastScan and exercise reconstruction.

        Covers reconstruct, reconstruct_n, reconstruct_batch and
        reconstruct_orig_invlists, with or without residual encoding.
        """
        d = 32
        metric = faiss.METRIC_L2

        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        index = faiss.IndexIVFPQFastScan(
            faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric
        )
        index.by_residual = by_residual
        # reconstruct() by id is exercised below, so maintain a direct map
        index.make_direct_map(True)
        index.train(ds.get_train())
        index.add(ds.get_database())

        # Test reconstruction
        single = index.reconstruct(123)  # single id
        self.assertEqual(single.shape, (d,))

        consecutive = index.reconstruct_n(123, 10)  # 10 consecutive ids
        self.assertEqual(consecutive.shape, (10, d))

        batch = index.reconstruct_batch(np.arange(10))  # explicit id list
        self.assertEqual(batch.shape, (10, d))

        # Test original list reconstruction: the destination lists start
        # empty and must hold every stored vector afterwards.
        index.orig_invlists = faiss.ArrayInvertedLists(
            index.nlist, index.code_size
        )
        index.reconstruct_orig_invlists()
        self.assertEqual(index.orig_invlists.compute_ntotal(), index.ntotal)

    def test_no_residual(self):
        self.do_test(by_residual=False)

    def test_by_residual(self):
        self.do_test(by_residual=True)


class TestIsTrained(unittest.TestCase):

def test_issue_2019(self):
Expand Down
Loading