Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SearchParameters support for IndexBinaryFlat #4055

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions faiss/IndexBinaryFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ void IndexBinaryFlat::search(
int32_t* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
// Extract IDSelector from params if present
const IDSelector* sel = params ? params->sel : nullptr;
FAISS_THROW_IF_NOT(k > 0);

const idx_t block_size = query_batch_size;
Expand All @@ -60,7 +60,8 @@ void IndexBinaryFlat::search(
ntotal,
code_size,
/* ordered = */ true,
approx_topk_mode);
approx_topk_mode,
sel);
} else {
hammings_knn_mc(
x + s * code_size,
Expand All @@ -70,7 +71,8 @@ void IndexBinaryFlat::search(
k,
code_size,
distances + s * k,
labels + s * k);
labels + s * k,
sel);
}
}
}
Expand Down Expand Up @@ -107,9 +109,9 @@ void IndexBinaryFlat::range_search(
int radius,
RangeSearchResult* result,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result);
const IDSelector* sel = params ? params->sel : nullptr;
hamming_range_search(
x, xb.data(), n, ntotal, radius, code_size, result, sel);
}

} // namespace faiss
42 changes: 30 additions & 12 deletions faiss/utils/approx_topk_hamming/approx_topk_hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ struct HeapWithBucketsForHamming32<
// output distances
int* const __restrict bh_val,
// output indices, each being within [0, n) range
int64_t* const __restrict bh_ids) {
int64_t* const __restrict bh_ids,
// optional id selector for filtering
const IDSelector* sel = nullptr) {
// forward a call to bs_addn with 1 beam
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids, sel);
}

static void bs_addn(
Expand All @@ -66,7 +68,9 @@ struct HeapWithBucketsForHamming32<
int* const __restrict bh_val,
// output indices, each being within [0, n_per_beam * beam_size)
// range
int64_t* const __restrict bh_ids) {
int64_t* const __restrict bh_ids,
// optional id selector for filtering
const IDSelector* sel = nullptr) {
//
using C = CMax<int, int64_t>;

Expand Down Expand Up @@ -95,11 +99,22 @@ struct HeapWithBucketsForHamming32<
for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
for (uint32_t j = 0; j < NBUCKETS_8; j++) {
uint32_t hamming_distances[8];
uint8_t valid_counter = 0;
for (size_t j8 = 0; j8 < 8; j8++) {
hamming_distances[j8] = hc.hamming(
binary_vectors +
(j8 + j * 8 + ip + n_per_beam * beam_index) *
code_size);
const uint32_t idx =
j8 + j * 8 + ip + n_per_beam * beam_index;
if (!sel || sel->is_member(idx)) {
hamming_distances[j8] = hc.hamming(
binary_vectors + idx * code_size);
valid_counter++;
} else {
hamming_distances[j8] =
std::numeric_limits<int32_t>::max();
}
}

if (valid_counter == 8) {
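continue; // Intended: skip if all 8 vectors are filtered out. NOTE(review): valid_counter counts the vectors that PASSED the selector, so the guard above should be `valid_counter == 0`; as written it skips groups where every vector is valid (every group when sel is null).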
}

// loop. Compiler should get rid of unneeded ops
Expand Down Expand Up @@ -157,7 +172,8 @@ struct HeapWithBucketsForHamming32<
const auto value = min_distances_scalar[j8];
const auto index = min_indices_scalar[j8];

if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
if (value < std::numeric_limits<int32_t>::max() &&
C::cmp2(bh_val[0], value, bh_ids[0], index)) {
heap_replace_top<C>(
k, bh_val, bh_ids, value, index);
}
Expand All @@ -168,11 +184,13 @@ struct HeapWithBucketsForHamming32<
// process leftovers
for (uint32_t ip = nb; ip < n_per_beam; ip++) {
const auto index = ip + n_per_beam * beam_index;
const auto value =
hc.hamming(binary_vectors + (index)*code_size);
if (!sel || sel->is_member(index)) {
const auto value =
hc.hamming(binary_vectors + (index)*code_size);

if (C::cmp(bh_val[0], value)) {
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
if (C::cmp(bh_val[0], value)) {
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
}
}
}
}
Expand Down
52 changes: 38 additions & 14 deletions faiss/utils/hamming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/approx_topk_hamming/approx_topk_hamming.h>
#include <faiss/utils/utils.h>
Expand Down Expand Up @@ -171,7 +172,8 @@ void hammings_knn_hc(
size_t n2,
bool order = true,
bool init_heap = true,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK) {
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
const faiss::IDSelector* sel = nullptr) {
size_t k = ha->k;
if (init_heap)
ha->heapify();
Expand Down Expand Up @@ -204,7 +206,7 @@ void hammings_knn_hc(
NB, \
BD, \
HammingComputer>:: \
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_); \
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_, sel); \
break;

switch (approx_topk_mode) {
Expand All @@ -214,6 +216,9 @@ void hammings_knn_hc(
HANDLE_APPROX(32, 2)
default: {
for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
if (sel && !sel->is_member(j)) {
continue;
}
dis = hc.hamming(bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_replace_top<hamdis_t>(
Expand All @@ -238,7 +243,8 @@ void hammings_knn_mc(
size_t nb,
size_t k,
int32_t* __restrict distances,
int64_t* __restrict labels) {
int64_t* __restrict labels,
const faiss::IDSelector* sel) {
const int nBuckets = bytes_per_code * 8 + 1;
std::vector<int> all_counters(na * nBuckets, 0);
std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);
Expand All @@ -259,7 +265,9 @@ void hammings_knn_mc(
#pragma omp parallel for
for (int64_t i = 0; i < na; ++i) {
for (size_t j = j0; j < j1; ++j) {
cs[i].update_counter(b + j * bytes_per_code, j);
if (!sel || sel->is_member(j)) {
cs[i].update_counter(b + j * bytes_per_code, j);
}
}
}
}
Expand Down Expand Up @@ -291,7 +299,8 @@ void hamming_range_search(
size_t nb,
int radius,
size_t code_size,
RangeSearchResult* res) {
RangeSearchResult* res,
const faiss::IDSelector* sel) {
#pragma omp parallel
{
RangeSearchPartialResult pres(res);
Expand All @@ -303,9 +312,11 @@ void hamming_range_search(
RangeQueryResult& qres = pres.new_result(i);

for (size_t j = 0; j < nb; j++) {
int dis = hc.hamming(yi);
if (dis < radius) {
qres.add(dis, j);
if (!sel || sel->is_member(j)) {
int dis = hc.hamming(yi);
if (dis < radius) {
qres.add(dis, j);
}
}
yi += code_size;
}
Expand Down Expand Up @@ -489,10 +500,21 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int order,
ApproxTopK_mode_t approx_topk_mode) {
ApproxTopK_mode_t approx_topk_mode,
const faiss::IDSelector* sel) {
Run_hammings_knn_hc r;
dispatch_HammingComputer(
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode);
ncodes,
r,
ncodes,
ha,
a,
b,
nb,
order,
true,
approx_topk_mode,
sel);
}

void hammings_knn_mc(
Expand All @@ -503,10 +525,11 @@ void hammings_knn_mc(
size_t k,
size_t ncodes,
int32_t* __restrict distances,
int64_t* __restrict labels) {
int64_t* __restrict labels,
const faiss::IDSelector* sel) {
Run_hammings_knn_mc r;
dispatch_HammingComputer(
ncodes, r, ncodes, a, b, na, nb, k, distances, labels);
ncodes, r, ncodes, a, b, na, nb, k, distances, labels, sel);
}

void hamming_range_search(
Expand All @@ -516,10 +539,11 @@ void hamming_range_search(
size_t nb,
int radius,
size_t code_size,
RangeSearchResult* result) {
RangeSearchResult* result,
const faiss::IDSelector* sel) {
Run_hamming_range_search r;
dispatch_HammingComputer(
code_size, r, a, b, na, nb, radius, code_size, result);
code_size, r, a, b, na, nb, radius, code_size, result, sel);
}

/* Count number of matches given a max threshold */
Expand Down
10 changes: 7 additions & 3 deletions faiss/utils/hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <stdint.h>

#include <faiss/impl/IDSelector.h>
#include <faiss/impl/platform_macros.h>
#include <faiss/utils/Heap.h>

Expand Down Expand Up @@ -135,7 +136,8 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int ordered,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK);
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
const IDSelector* sel = nullptr);
gustavz marked this conversation as resolved.
Show resolved Hide resolved

/* Legacy alias to hammings_knn_hc. */
void hammings_knn(
Expand Down Expand Up @@ -166,7 +168,8 @@ void hammings_knn_mc(
size_t k,
size_t ncodes,
int32_t* distances,
int64_t* labels);
int64_t* labels,
const IDSelector* sel = nullptr);

/** same as hammings_knn except we are doing a range search with radius */
void hamming_range_search(
Expand All @@ -176,7 +179,8 @@ void hamming_range_search(
size_t nb,
int radius,
size_t ncodes,
RangeSearchResult* result);
RangeSearchResult* result,
const IDSelector* sel = nullptr);

/* Counting the number of matches or of cross-matches (without returning them)
For use with functions that assume pre-allocated memory */
Expand Down
Loading
Loading