Skip to content

Commit

Permalink
fix import, no heap test, linting
Browse files Browse the repository at this point in the history
  • Loading branch information
gustavz committed Dec 16, 2024
1 parent d70e64b commit 539dfae
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 22 deletions.
3 changes: 2 additions & 1 deletion faiss/IndexBinaryFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ void IndexBinaryFlat::range_search(
RangeSearchResult* result,
const SearchParameters* params) const {
const IDSelector* sel = params ? params->sel : nullptr;
hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result, sel);
hamming_range_search(
x, xb.data(), n, ntotal, radius, code_size, result, sel);
}

} // namespace faiss
8 changes: 5 additions & 3 deletions faiss/utils/approx_topk_hamming/approx_topk_hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,19 @@ struct HeapWithBucketsForHamming32<
uint32_t hamming_distances[8];
uint8_t valid_counter = 0;
for (size_t j8 = 0; j8 < 8; j8++) {
const uint32_t idx = j8 + j * 8 + ip + n_per_beam * beam_index;
const uint32_t idx =
j8 + j * 8 + ip + n_per_beam * beam_index;
if (!sel || sel->is_member(idx)) {
hamming_distances[j8] = hc.hamming(
binary_vectors + idx * code_size);
valid_counter++;
} else {
hamming_distances[j8] = std::numeric_limits<int32_t>::max();
hamming_distances[j8] =
std::numeric_limits<int32_t>::max();
}
}

if (valid_counter == 0) {
if (valid_counter == 8) {
continue; // Skip if all vectors are filtered out
}

Expand Down
16 changes: 13 additions & 3 deletions faiss/utils/hamming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void hammings_knn_hc(
size_t n2,
bool order = true,
bool init_heap = true,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
const IDSelector* sel = nullptr) {
size_t k = ha->k;
if (init_heap)
Expand Down Expand Up @@ -500,11 +500,21 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int order,
ApproxTopK_mode_t approx_topk_mode
ApproxTopK_mode_t approx_topk_mode,
const IDSelector* sel) {
Run_hammings_knn_hc r;
dispatch_HammingComputer(
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode, sel);
ncodes,
r,
ncodes,
ha,
a,
b,
nb,
order,
true,
approx_topk_mode,
sel);
}

void hammings_knn_mc(
Expand Down
1 change: 1 addition & 0 deletions faiss/utils/hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <stdint.h>

#include <faiss/impl/IDSelector.h>
#include <faiss/impl/platform_macros.h>
#include <faiss/utils/Heap.h>

Expand Down
23 changes: 8 additions & 15 deletions tests/test_search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ class TestSelector(unittest.TestCase):
combinations as possible.
"""

def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10, params=None):
def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10, use_heap=True):
""" Verify that the id selector returns the subset of results that are
members according to the IDSelector.
Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "and", "or", "xor"
params: optional SearchParameters object to override default settings
"""
d = 32 # make sure dimension is multiple of 8 for binary
ds = datasets.SyntheticDataset(d, 1000, 100, 20)
Expand All @@ -38,6 +37,7 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
xq = rs.randint(256, size=(ds.nq, d // 8), dtype='uint8')
xt = None # No training needed for binary flat
index = faiss.IndexBinaryFlat(d)
index.use_heap = use_heap
# Use smaller radius for Hamming distance
base_radius = 4
else:
Expand Down Expand Up @@ -137,15 +137,11 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
else:
sel = faiss.IDSelectorBatch(subset)

if params is None:
params = (
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
faiss.SearchParameters(sel=sel)
)
else:
# Use provided params but ensure selector is set
params.sel = sel
params = (
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
faiss.SearchParameters(sel=sel)
)

Dnew, Inew = index.search(xq, k, params=params)
np.testing.assert_array_equal(Iref, Inew)
Expand Down Expand Up @@ -317,10 +313,7 @@ def test_BinaryFlat_id_array(self):
self.do_test_id_selector("BinaryFlat", id_selector_type="array")

def test_BinaryFlat_no_heap(self):
params = faiss.SearchParameters()
params.use_heap = False
self.do_test_id_selector("BinaryFlat", params=params)

self.do_test_id_selector("BinaryFlat", use_heap=False)

class TestSearchParams(unittest.TestCase):

Expand Down

0 comments on commit 539dfae

Please sign in to comment.