Skip to content

Commit

Permalink
add missing sel to hamming.h, add no heap test case, simplify valid_c…
Browse files Browse the repository at this point in the history
…ounter
  • Loading branch information
gustavz committed Dec 12, 2024
1 parent c5c9cab commit 2717b86
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
6 changes: 3 additions & 3 deletions faiss/utils/approx_topk_hamming/approx_topk_hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,19 @@ struct HeapWithBucketsForHamming32<
for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
for (uint32_t j = 0; j < NBUCKETS_8; j++) {
uint32_t hamming_distances[8];
uint32_t valid_mask = 0;
uint8_t valid_counter = 0;
for (size_t j8 = 0; j8 < 8; j8++) {
const uint32_t idx = j8 + j * 8 + ip + n_per_beam * beam_index;
if (!sel || sel->is_member(idx)) {
hamming_distances[j8] = hc.hamming(
binary_vectors + idx * code_size);
valid_mask |= (1 << j8);
valid_counter++;
} else {
hamming_distances[j8] = std::numeric_limits<int32_t>::max();
}
}

if (valid_mask == 0) {
if (valid_counter == 0) {
continue; // Skip if all vectors are filtered out
}

Expand Down
9 changes: 6 additions & 3 deletions faiss/utils/hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int ordered,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK);
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
const IDSelector* sel = nullptr);

/* Legacy alias to hammings_knn_hc. */
void hammings_knn(
Expand Down Expand Up @@ -166,7 +167,8 @@ void hammings_knn_mc(
size_t k,
size_t ncodes,
int32_t* distances,
int64_t* labels);
int64_t* labels,
const IDSelector* sel = nullptr);

/** same as hammings_knn except we are doing a range search with radius */
void hamming_range_search(
Expand All @@ -176,7 +178,8 @@ void hamming_range_search(
size_t nb,
int radius,
size_t ncodes,
RangeSearchResult* result);
RangeSearchResult* result,
const IDSelector* sel = nullptr);

/* Counting the number of matches or of cross-matches (without returning them)
For use with function that assume pre-allocated memory */
Expand Down
25 changes: 19 additions & 6 deletions tests/test_search_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ class TestSelector(unittest.TestCase):
combinations as possible.
"""

def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10):
def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10, params=None):
""" Verify that the id selector returns the subset of results that are
members according to the IDSelector.
Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "and", "or", "xor"
params: optional SearchParameters object to override default settings
"""
d = 32 # make sure dimension is multiple of 8 for binary
ds = datasets.SyntheticDataset(d, 1000, 100, 20)
Expand Down Expand Up @@ -73,6 +74,8 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
subset = rs.choice(ds.nb, 50, replace=False).astype('int64')

index.add(xb[subset])
if "IVF" in index_key and id_selector_type == "range_sorted":
self.assertTrue(index.check_ids_sorted())
Dref, Iref0 = index.search(xq, k)
Iref = subset[Iref0]
Iref[Iref0 < 0] = -1
Expand Down Expand Up @@ -134,11 +137,16 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
else:
sel = faiss.IDSelectorBatch(subset)

params = (
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
faiss.SearchParameters(sel=sel)
)
if params is None:
params = (
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
faiss.SearchParameters(sel=sel)
)
else:
# Use provided params but ensure selector is set
params.sel = sel

Dnew, Inew = index.search(xq, k, params=params)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_almost_equal(Dref, Dnew, decimal=5)
Expand Down Expand Up @@ -308,6 +316,11 @@ def test_BinaryFlat_id_range(self):
def test_BinaryFlat_id_array(self):
self.do_test_id_selector("BinaryFlat", id_selector_type="array")

def test_BinaryFlat_no_heap(self):
params = faiss.SearchParameters()
params.use_heap = False
self.do_test_id_selector("BinaryFlat", params=params)


class TestSearchParams(unittest.TestCase):

Expand Down

0 comments on commit 2717b86

Please sign in to comment.