From 2717b86da5be3b4b717f0a037601b38bef57a6d6 Mon Sep 17 00:00:00 2001 From: Gustav von Zitzewitz Date: Thu, 12 Dec 2024 09:21:15 +0100 Subject: [PATCH] add missing sel to hamming.h, add no heap test case, simplify valid_counter --- .../approx_topk_hamming/approx_topk_hamming.h | 6 ++--- faiss/utils/hamming.h | 9 ++++--- tests/test_search_params.py | 25 ++++++++++++++----- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/faiss/utils/approx_topk_hamming/approx_topk_hamming.h b/faiss/utils/approx_topk_hamming/approx_topk_hamming.h index 5ee2930a01..4efd24d7c1 100644 --- a/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +++ b/faiss/utils/approx_topk_hamming/approx_topk_hamming.h @@ -99,19 +99,19 @@ struct HeapWithBucketsForHamming32< for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) { for (uint32_t j = 0; j < NBUCKETS_8; j++) { uint32_t hamming_distances[8]; - uint32_t valid_mask = 0; + uint8_t valid_counter = 0; for (size_t j8 = 0; j8 < 8; j8++) { const uint32_t idx = j8 + j * 8 + ip + n_per_beam * beam_index; if (!sel || sel->is_member(idx)) { hamming_distances[j8] = hc.hamming( binary_vectors + idx * code_size); - valid_mask |= (1 << j8); + valid_counter++; } else { hamming_distances[j8] = std::numeric_limits::max(); } } - if (valid_mask == 0) { + if (valid_counter == 0) { continue; // Skip if all vectors are filtered out } diff --git a/faiss/utils/hamming.h b/faiss/utils/hamming.h index 85f9730e5c..4d72218c70 100644 --- a/faiss/utils/hamming.h +++ b/faiss/utils/hamming.h @@ -135,7 +135,8 @@ void hammings_knn_hc( size_t nb, size_t ncodes, int ordered, - ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK); + ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK, + const IDSelector* sel = nullptr); /* Legacy alias to hammings_knn_hc. */ void hammings_knn( @@ -166,7 +167,8 @@ void hammings_knn_mc( size_t k, size_t ncodes, int32_t* distances, - int64_t* labels); + int64_t* labels, + const IDSelector* sel = nullptr); /** same as hammings_knn except we are doing a range search with radius */ void hamming_range_search( @@ -176,7 +178,8 @@ void hamming_range_search( size_t nb, int radius, size_t ncodes, - RangeSearchResult* result); + RangeSearchResult* result, + const IDSelector* sel = nullptr); /* Counting the number of matches or of cross-matches (without returning them) For use with function that assume pre-allocated memory */ diff --git a/tests/test_search_params.py b/tests/test_search_params.py index 4337a70743..74e843bbbf 100644 --- a/tests/test_search_params.py +++ b/tests/test_search_params.py @@ -22,10 +22,11 @@ class TestSelector(unittest.TestCase): combinations as possible. """ - def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10): + def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10, params=None): """ Verify that the id selector returns the subset of results that are members according to the IDSelector. Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "and", "or", "xor" + params: optional SearchParameters object to override default settings """ d = 32 # make sure dimension is multiple of 8 for binary ds = datasets.SyntheticDataset(d, 1000, 100, 20) @@ -73,6 +74,8 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR subset = rs.choice(ds.nb, 50, replace=False).astype('int64') index.add(xb[subset]) + if "IVF" in index_key and id_selector_type == "range_sorted": + self.assertTrue(index.check_ids_sorted()) Dref, Iref0 = index.search(xq, k) Iref = subset[Iref0] Iref[Iref0 < 0] = -1 @@ -134,11 +137,16 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR else: sel = faiss.IDSelectorBatch(subset) - params = ( - faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else - faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else - faiss.SearchParameters(sel=sel) - ) + if params is None: + params = ( + faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else + faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else + faiss.SearchParameters(sel=sel) + ) + else: + # Use provided params but ensure selector is set + params.sel = sel + Dnew, Inew = index.search(xq, k, params=params) np.testing.assert_array_equal(Iref, Inew) np.testing.assert_almost_equal(Dref, Dnew, decimal=5) @@ -308,6 +316,11 @@ def test_BinaryFlat_id_range(self): def test_BinaryFlat_id_array(self): self.do_test_id_selector("BinaryFlat", id_selector_type="array") + def test_BinaryFlat_no_heap(self): + params = faiss.SearchParameters() + params.use_heap = False + self.do_test_id_selector("BinaryFlat", params=params) + class TestSearchParams(unittest.TestCase):