diff --git a/CMakeLists.txt b/CMakeLists.txt index c4fea46a3d..571aa965ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -119,3 +119,5 @@ if(BUILD_TESTING) endif() endif() endif() + +add_subdirectory(faiss/cppcontrib/knowhere) diff --git a/benchs/CMakeLists.txt b/benchs/CMakeLists.txt index 46c81ae248..88e13be322 100644 --- a/benchs/CMakeLists.txt +++ b/benchs/CMakeLists.txt @@ -8,4 +8,3 @@ add_executable(bench_ivf_selector EXCLUDE_FROM_ALL bench_ivf_selector.cpp) target_link_libraries(bench_ivf_selector PRIVATE faiss) - diff --git a/faiss/cppcontrib/knowhere/CMakeLists.txt b/faiss/cppcontrib/knowhere/CMakeLists.txt new file mode 100644 index 0000000000..75b29ed700 --- /dev/null +++ b/faiss/cppcontrib/knowhere/CMakeLists.txt @@ -0,0 +1,29 @@ +cmake_minimum_required(VERSION 3.24.0) + +project(faiss_cppcontrib_knowhere) + +find_package(OpenMP REQUIRED) + +set(SRC_FAISS_CPPCONTRIB_KNOWHERE + IndexBruteForceWrapper.cpp + IndexHNSWWrapper.cpp + IndexWrapper.cpp +) + +include(${PROJECT_SOURCE_DIR}/../../../cmake/link_to_faiss_lib.cmake) + +add_library(faiss_cppcontrib_knowhere ${SRC_FAISS_CPPCONTRIB_KNOWHERE}) +link_to_faiss_lib(faiss_cppcontrib_knowhere) + +target_include_directories(faiss_cppcontrib_knowhere PRIVATE + ${PROJECT_SOURCE_DIR}/../../.. +) + +add_executable(bench_hnsw_knowhere benchs/bench_hnsw_knowhere.cpp) + +link_to_faiss_lib(bench_hnsw_knowhere) +target_link_libraries(bench_hnsw_knowhere PRIVATE faiss_cppcontrib_knowhere OpenMP::OpenMP_CXX) + +target_include_directories(bench_hnsw_knowhere PRIVATE + ${PROJECT_SOURCE_DIR}/../../.. +) diff --git a/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp b/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp new file mode 100644 index 0000000000..420677fcdc --- /dev/null +++ b/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp @@ -0,0 +1,223 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +IndexBruteForceWrapper::IndexBruteForceWrapper(Index* underlying_index) + : IndexWrapper{underlying_index} {} + +void IndexBruteForceWrapper::search( + idx_t n, + const float* __restrict x, + idx_t k, + float* __restrict distances, + idx_t* __restrict labels, + const SearchParameters* __restrict params) const { + FAISS_THROW_IF_NOT(k > 0); + + idx_t check_period = + InterruptCallback::get_period_hint(index->d * index->ntotal); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel if (i1 - i0 > 1) + { + std::unique_ptr dis( + index->get_distance_computer()); + +#pragma omp for schedule(guided) + for (idx_t i = i0; i < i1; i++) { + // prepare the query + dis->set_query(x + i * index->d); + + // allocate heap + idx_t* const __restrict local_ids = labels + i * index->d; + float* const __restrict local_distances = + distances + i * index->d; + + // set up a filter + IDSelector* __restrict sel = + (params == nullptr) ? nullptr : params->sel; + + if (is_similarity_metric(index->metric_type)) { + using C = CMin; + + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method + // calls + IDSelectorAll sel_all; + brute_force_search_impl< + C, + DistanceComputer, + IDSelectorAll>( + index->ntotal, + *dis, + sel_all, + k, + local_distances, + local_ids); + } else { + brute_force_search_impl< + C, + DistanceComputer, + IDSelector>( + index->ntotal, + *dis, + *sel, + k, + local_distances, + local_ids); + } + } else { + using C = CMax; + + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method + // calls + IDSelectorAll sel_all; + brute_force_search_impl< + C, + DistanceComputer, + IDSelectorAll>( + index->ntotal, + *dis, + sel_all, + k, + local_distances, + local_ids); + } else { + brute_force_search_impl< + C, + DistanceComputer, + IDSelector>( + index->ntotal, + *dis, + *sel, + k, + local_distances, + local_ids); + } + } + } + } + + InterruptCallback::check(); + } +} + +void IndexBruteForceWrapper::range_search( + idx_t n, + const float* __restrict x, + float radius, + RangeSearchResult* __restrict result, + const SearchParameters* __restrict params) const { + using RH_min = RangeSearchBlockResultHandler>; + using RH_max = RangeSearchBlockResultHandler>; + RH_min bres_min(result, radius); + RH_max bres_max(result, radius); + + idx_t check_period = + InterruptCallback::get_period_hint(index->d * index->ntotal); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel if (i1 - i0 > 1) + { + std::unique_ptr dis( + index->get_distance_computer()); + + typename RH_min::SingleResultHandler res_min(bres_min); + typename RH_max::SingleResultHandler res_max(bres_max); + +#pragma omp for schedule(guided) + for (idx_t i = i0; i < i1; i++) { + // prepare the query + dis->set_query(x + i * index->d); + + // set up a filter + IDSelector* __restrict sel = + (params == nullptr) ? nullptr : params->sel; + + if (is_similarity_metric(index->metric_type)) { + res_max.begin(i); + + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method + // calls + IDSelectorAll sel_all; + + brute_force_range_search_impl< + typename RH_max::SingleResultHandler, + DistanceComputer, + IDSelectorAll>( + index->ntotal, *dis, sel_all, res_max); + } else { + brute_force_range_search_impl< + typename RH_max::SingleResultHandler, + DistanceComputer, + IDSelector>(index->ntotal, *dis, *sel, res_max); + } + + res_max.end(); + } else { + res_min.begin(i); + + if (sel == nullptr) { + // Compiler is expected to de-virtualize virtual method + // calls + IDSelectorAll sel_all; + + brute_force_range_search_impl< + typename RH_min::SingleResultHandler, + DistanceComputer, + IDSelectorAll>( + index->ntotal, *dis, sel_all, res_min); + } else { + brute_force_range_search_impl< + typename RH_min::SingleResultHandler, + DistanceComputer, + IDSelector>(index->ntotal, *dis, *sel, res_min); + } + + res_min.end(); + } + } + } + + InterruptCallback::check(); + } +} + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h b/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h new file mode 100644 index 0000000000..ef9194adeb --- /dev/null +++ b/faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// override a search procedure to perform a brute-force search. +struct IndexBruteForceWrapper : IndexWrapper { + IndexBruteForceWrapper(Index* underlying_index); + + /// entry point for search + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params) const override; + + /// entry point for range search + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params) const override; +}; + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp b/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp new file mode 100644 index 0000000000..2161a470d3 --- /dev/null +++ b/faiss/cppcontrib/knowhere/IndexHNSWWrapper.cpp @@ -0,0 +1,382 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// a visitor that does nothing +struct DummyVisitor { + using storage_idx_t = HNSW::storage_idx_t; + + void visit_level(const int level) { + // does nothing + } + + void visit_edge( + const int level, + const storage_idx_t node_from, + const storage_idx_t node_to, + const float distance) { + // does nothing + } +}; + +/************************************************************** + * Utilities + **************************************************************/ + +namespace { + +// cloned from IndexHNSW.cpp +DistanceComputer* storage_distance_computer(const Index* storage) { + if (is_similarity_metric(storage->metric_type)) { + return new NegativeDistanceComputer(storage->get_distance_computer()); + } else { + return storage->get_distance_computer(); + } +} + +} // namespace + +/************************************************************** + * IndexHNSWWrapper implementation + **************************************************************/ + +IndexHNSWWrapper::IndexHNSWWrapper(IndexHNSW* underlying_index) + : IndexWrapper(underlying_index) {} + +void IndexHNSWWrapper::search( + idx_t n, + const float* __restrict x, + idx_t k, + float* __restrict distances, + idx_t* __restrict labels, + const SearchParameters* params_in) const { + FAISS_THROW_IF_NOT(k > 0); + + const IndexHNSW* index_hnsw = dynamic_cast(index); + FAISS_THROW_IF_NOT(index_hnsw); + + FAISS_THROW_IF_NOT_MSG(index_hnsw->storage, "No storage index"); + + // set up + using C = HNSW::C; + + // check if the graph is empty + if (index_hnsw->hnsw.entry_point == -1) { + for (idx_t i = 0; i < k * n; i++) { + distances[i] = C::neutral(); + labels[i] = -1; + } + + return; + } + + // check parameters + const SearchParametersHNSWWrapper* params = nullptr; + const HNSW& hnsw = index_hnsw->hnsw; + + float kAlpha = 0.0f; + int efSearch = hnsw.efSearch; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); + efSearch = params->efSearch; + kAlpha = params->kAlpha; + } + + // set up hnsw_stats + HNSWStats* __restrict const hnsw_stats = + (params == nullptr) ? nullptr : params->hnsw_stats; + + // + size_t n1 = 0; + size_t n2 = 0; + size_t ndis = 0; + size_t nhops = 0; + + idx_t check_period = InterruptCallback::get_period_hint( + hnsw.max_level * index->d * efSearch); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel if (i1 - i0 > 1) + { + Bitset bitset_visited_nodes = + Bitset::create_uninitialized(index->ntotal); + + // create a distance computer + std::unique_ptr dis( + storage_distance_computer(index_hnsw->storage)); + +#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) + for (idx_t i = i0; i < i1; i++) { + // prepare the query + dis->set_query(x + i * index->d); + + // prepare the table of visited elements + bitset_visited_nodes.clear(); + + // a visitor + DummyVisitor graph_visitor; + + // future results + HNSWStats local_stats; + + // set up a filter + IDSelector* sel = (params == nullptr) ? nullptr : params->sel; + if (sel == nullptr) { + // no filter. + // It it expected that a compile will be able to + // de-virtualize the class. + IDSelectorAll sel_all; + + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelectorAll>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + sel_all, + kAlpha, + params}; + + local_stats = searcher.search( + k, distances + i * k, labels + i * k); + } else { + // there is a filter + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelector>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + *sel, + kAlpha, + params}; + + local_stats = searcher.search( + k, distances + i * k, labels + i * k); + } + + // update stats if possible + if (hnsw_stats != nullptr) { + n1 += local_stats.n1; + n2 += local_stats.n2; + ndis += local_stats.ndis; + nhops += local_stats.nhops; + } + } + } + + InterruptCallback::check(); + } + + // update stats if possible + if (hnsw_stats != nullptr) { + hnsw_stats->combine({n1, n2, ndis, nhops}); + } + + // done, update the results, if needed + if (is_similarity_metric(index->metric_type)) { + // we need to revert the negated distances + for (idx_t i = 0; i < k * n; i++) { + distances[i] = -distances[i]; + } + } +} + +void IndexHNSWWrapper::range_search( + idx_t n, + const float* x, + float radius_in, + RangeSearchResult* result, + const SearchParameters* params_in) const { + const IndexHNSW* index_hnsw = dynamic_cast(index); + FAISS_THROW_IF_NOT(index_hnsw); + + FAISS_THROW_IF_NOT_MSG(index_hnsw->storage, "No storage index"); + + // check if the graph is empty + if (index_hnsw->hnsw.entry_point == -1) { + return; + } + + // check parameters + const SearchParametersHNSWWrapper* params = nullptr; + const HNSW& hnsw = index_hnsw->hnsw; + + float kAlpha = 0.0f; + int efSearch = hnsw.efSearch; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); + + kAlpha = params->kAlpha; + efSearch = params->efSearch; + } + + // set up hnsw_stats + HNSWStats* __restrict const hnsw_stats = + (params == nullptr) ? nullptr : params->hnsw_stats; + + // + size_t n1 = 0; + size_t n2 = 0; + size_t ndis = 0; + size_t nhops = 0; + + // radius + float radius = radius_in; + if (is_similarity_metric(this->metric_type)) { + radius *= (-1); + } + + // initialize a ResultHandler + using RH_min = RangeSearchBlockResultHandler>; + RH_min bres_min(result, radius); + + // no parallelism by design + idx_t check_period = InterruptCallback::get_period_hint( + hnsw.max_level * index->d * efSearch); + + for (idx_t i0 = 0; i0 < n; i0 += check_period) { + idx_t i1 = std::min(i0 + check_period, n); + +#pragma omp parallel if (i1 - i0 > 1) + { + // + Bitset bitset_visited_nodes = + Bitset::create_uninitialized(index->ntotal); + + // create a distance computer + std::unique_ptr dis( + storage_distance_computer(index_hnsw->storage)); + +#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) + for (idx_t i = i0; i < i1; i++) { + typename RH_min::SingleResultHandler res_min(bres_min); + res_min.begin(i); + + // prepare the query + dis->set_query(x + i * index->d); + + // prepare the table of visited elements + bitset_visited_nodes.clear(); + + // future results + HNSWStats local_stats; + + // set up a filter + IDSelector* sel = (params == nullptr) ? nullptr : params->sel; + + if (sel == nullptr) { + IDSelectorAll sel_all; + DummyVisitor graph_visitor; + + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelectorAll>; + + searcher_type searcher( + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + sel_all, + kAlpha, + params); + + local_stats = searcher.range_search(radius, &res_min); + } else { + DummyVisitor graph_visitor; + + using searcher_type = v2_hnsw_searcher< + DistanceComputer, + DummyVisitor, + Bitset, + IDSelector>; + + searcher_type searcher{ + hnsw, + *(dis.get()), + graph_visitor, + bitset_visited_nodes, + *sel, + kAlpha, + params}; + + local_stats = searcher.range_search(radius, &res_min); + } + + // update stats if possible + if (hnsw_stats != nullptr) { + n1 += local_stats.n1; + n2 += local_stats.n2; + ndis += local_stats.ndis; + nhops += local_stats.nhops; + } + + // + res_min.end(); + } + } + } + + // update stats if possible + if (hnsw_stats != nullptr) { + hnsw_stats->combine({n1, n2, ndis, nhops}); + } + + // done, update the results, if needed + if (is_similarity_metric(this->metric_type)) { + // we need to revert the negated distances + for (size_t i = 0; i < result->lims[result->nq]; i++) { + result->distances[i] = -result->distances[i]; + } + } +} + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h b/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h new file mode 100644 index 0000000000..e80495e47f --- /dev/null +++ b/faiss/cppcontrib/knowhere/IndexHNSWWrapper.h @@ -0,0 +1,66 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include +#include + +#include + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// Custom parameters for IndexHNSW. +struct SearchParametersHNSWWrapper : public SearchParametersHNSW { + // Stats will be updated if the object pointer is provided. + HNSWStats* hnsw_stats = nullptr; + // filtering parameter. A floating point value within [0.0f, 1.0f range] + float kAlpha = 0.0f; + + inline ~SearchParametersHNSWWrapper() {} +}; + +// TODO: +// Please note that this particular searcher is int32_t based, so won't +// work correctly for 2B+ samples. This can be easily changed, if needed. + +// override a search() procedure for IndexHNSW. +struct IndexHNSWWrapper : IndexWrapper { + IndexHNSWWrapper(IndexHNSW* underlying_index); + + /// entry point for search + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params = nullptr) const override; + + /// entry point for range search + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params) const override; +}; + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/IndexWrapper.cpp b/faiss/cppcontrib/knowhere/IndexWrapper.cpp new file mode 100644 index 0000000000..3ad50976c5 --- /dev/null +++ b/faiss/cppcontrib/knowhere/IndexWrapper.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +IndexWrapper::IndexWrapper(Index* underlying_index) + : Index{underlying_index->d, underlying_index->metric_type}, + index{underlying_index} { + ntotal = underlying_index->ntotal; + is_trained = underlying_index->is_trained; + verbose = underlying_index->verbose; + metric_arg = underlying_index->metric_arg; +} + +IndexWrapper::~IndexWrapper() {} + +void IndexWrapper::train(idx_t n, const float* x) { + index->train(n, x); + is_trained = index->is_trained; +} + +void IndexWrapper::add(idx_t n, const float* x) { + index->add(n, x); + this->ntotal = index->ntotal; +} + +void IndexWrapper::search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params) const { + index->search(n, x, k, distances, labels, params); +} + +void IndexWrapper::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params) const { + index->range_search(n, x, radius, result, params); +} + +void IndexWrapper::reset() { + index->reset(); + this->ntotal = 0; +} + +void IndexWrapper::merge_from(Index& otherIndex, idx_t add_id) { + index->merge_from(otherIndex, add_id); +} + +DistanceComputer* IndexWrapper::get_distance_computer() const { + return index->get_distance_computer(); +} + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/IndexWrapper.h b/faiss/cppcontrib/knowhere/IndexWrapper.h new file mode 100644 index 0000000000..2ba2859514 --- /dev/null +++ b/faiss/cppcontrib/knowhere/IndexWrapper.h @@ -0,0 +1,61 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// This index is useful for overriding a certain base functionality +// on the fly. +struct IndexWrapper : Index { + // a non-owning pointer + Index* index = nullptr; + + explicit IndexWrapper(Index* underlying_index); + + virtual ~IndexWrapper(); + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params = nullptr) const override; + + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params = nullptr) const override; + + void reset() override; + + void merge_from(Index& otherIndex, idx_t add_id = 0) override; + + DistanceComputer* get_distance_computer() const override; +}; + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/README.txt b/faiss/cppcontrib/knowhere/README.txt new file mode 100644 index 0000000000..f89ffcd0c5 --- /dev/null +++ b/faiss/cppcontrib/knowhere/README.txt @@ -0,0 +1 @@ +from https://github.com/zilliztech/knowhere diff --git a/faiss/cppcontrib/knowhere/benchs/bench_hnsw_knowhere.cpp b/faiss/cppcontrib/knowhere/benchs/bench_hnsw_knowhere.cpp new file mode 100644 index 0000000000..012908e0de --- /dev/null +++ b/faiss/cppcontrib/knowhere/benchs/bench_hnsw_knowhere.cpp @@ -0,0 +1,221 @@ +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +std::vector generate_dataset( + const size_t n, + const size_t d, + uint64_t seed) { + std::default_random_engine rng(seed); + std::uniform_real_distribution u(-1, 1); + + std::vector data(n * d); + for (size_t i = 0; i < data.size(); i++) { + data[i] = u(rng); + } + + return data; +} + +float get_recall_rate( + const size_t nq, + const size_t k, + const std::vector& baseline, + const std::vector& candidate) { + size_t n = 0; + for (size_t i = 0; i < nq; i++) { + std::unordered_set a_set(k * 4); + + for (size_t j = 0; j < k; j++) { + a_set.insert(baseline[i * k + j]); + } + + for (size_t j = 0; j < k; j++) { + auto itr = a_set.find(candidate[i * k + j]); + if (itr != a_set.cend()) { + n += 1; + } + } + } + + return (float)n / candidate.size(); +} + +struct StopWatch { + using timepoint_t = std::chrono::time_point; + timepoint_t Start; + + StopWatch() { + Start = std::chrono::steady_clock::now(); + } + + double elapsed() const { + const auto now = std::chrono::steady_clock::now(); + std::chrono::duration elapsed = now - Start; + return elapsed.count(); + } +}; + +void test(const size_t nt, const size_t d, const size_t nq, const size_t k) { + // generate a dataset for train + std::vector xt = generate_dataset(nt, d, 123); + + // create an baseline + std::unique_ptr baseline_index( + faiss::index_factory(d, "Flat", faiss::MetricType::METRIC_L2)); + baseline_index->train(nt, xt.data()); + baseline_index->add(nt, xt.data()); + + // create an hnsw index + std::unique_ptr hnsw_index(faiss::index_factory( + d, "HNSW32,Flat", faiss::MetricType::METRIC_L2)); + hnsw_index->train(nt, xt.data()); + hnsw_index->add(nt, xt.data()); + + // generate a query dataset + std::vector xq = generate_dataset(nq, d, 123); + + // a seed + std::default_random_engine rng(789); + + // print header + printf("d=%zd, nt=%zd, nq=%zd\n", d, nt, nq); + + // perform evaluation with a different level of filtering + for (const size_t percent : + {0, 1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 99}) { + // generate a bitset with a given percentage + std::vector ids_to_use(nt); + std::iota(ids_to_use.begin(), ids_to_use.end(), 0); + + std::shuffle(ids_to_use.begin(), ids_to_use.end(), rng); + + // number of points to use + const size_t nt_real = + size_t(std::max(1.0, nt - (nt * percent / 100.0))); + + // create a bitset + faiss::cppcontrib::knowhere::Bitset bitset = + faiss::cppcontrib::knowhere::Bitset::create_cleared(nt); + for (size_t i = 0; i < nt_real; i++) { + bitset.set(ids_to_use[i]); + } + + // create an IDSelector + faiss::IDSelectorBitmap sel(nt, bitset.bits.get()); + + // the quant of a search + const size_t nbatch = nq; + + // perform a baseline search + std::vector baseline_dis(k * nq, -1); + std::vector baseline_ids(k * nq, -1); + + faiss::SearchParameters baseline_params; + baseline_params.sel = &sel; + + StopWatch sw_baseline; + for (size_t p = 0; p < nq; p += nbatch) { + size_t p0 = std::min(nq, p + nbatch); + size_t np = p0 - p; + + baseline_index->search( + np, + xq.data() + p * d, + k, + baseline_dis.data() + k * p, + baseline_ids.data() + k * p, + &baseline_params); + } + double baseline_elapsed = sw_baseline.elapsed(); + + // perform an hnsw search + std::vector hnsw_dis(k * nq, -1); + std::vector hnsw_ids(k * nq, -1); + + faiss::SearchParametersHNSW hnsw_params; + hnsw_params.sel = &sel; + hnsw_params.efSearch = 64; + + StopWatch sw_hnsw; + for (size_t p = 0; p < nq; p += nbatch) { + size_t p0 = std::min(nq, p + nbatch); + size_t np = p0 - p; + + // hnsw_index->search(nq, xq.data(), k, hnsw_dis.data(), + // hnsw_ids.data(), &hnsw_params); + hnsw_index->search( + np, + xq.data() + p * d, + k, + hnsw_dis.data() + k * p, + hnsw_ids.data() + k * p, + &hnsw_params); + } + double hnsw_elapsed = sw_hnsw.elapsed(); + + // perform a cppcontrib/knowhere search + std::vector hnsw_candidate_dis(k * nq, -1); + std::vector hnsw_candidate_ids(k * nq, -1); + + faiss::cppcontrib::knowhere::SearchParametersHNSWWrapper + hnsw_candidate_params; + hnsw_candidate_params.sel = &sel; + hnsw_candidate_params.kAlpha = ((float)nt_real / nt) * 0.7f; + hnsw_candidate_params.efSearch = 64; + + faiss::cppcontrib::knowhere::IndexHNSWWrapper wrapper( + dynamic_cast(hnsw_index.get())); + + StopWatch sw_hnsw_candidate; + for (size_t p = 0; p < nq; p += nbatch) { + size_t p0 = std::min(nq, p + nbatch); + size_t np = p0 - p; + + // wrapper.search(nq, xq.data(), k, hnsw_candidate_dis.data(), + // hnsw_candidate_ids.data(), &hnsw_candidate_params); + wrapper.search( + np, + xq.data() + p * d, + k, + hnsw_candidate_dis.data() + k * p, + hnsw_candidate_ids.data() + k * p, + &hnsw_candidate_params); + } + double hnsw_candidate_elapsed = sw_hnsw_candidate.elapsed(); + + // compute the recall rate + const float recall_hnsw = + get_recall_rate(nq, k, baseline_ids, hnsw_ids); + const float recall_hnsw_candidate = + get_recall_rate(nq, k, baseline_ids, hnsw_candidate_ids); + + printf("perc=%zd, R_baseline=%f, R_candidate=%f, t_baseline=%f ms, t_candidate=%f ms\n", + percent, + recall_hnsw, + recall_hnsw_candidate, + hnsw_elapsed, + hnsw_candidate_elapsed); + } +} + +int main() { + // this takes time to eval + test(65536, 256, 1024, 64); + + return 0; +} \ No newline at end of file diff --git a/faiss/cppcontrib/knowhere/impl/Bruteforce.h b/faiss/cppcontrib/knowhere/impl/Bruteforce.h new file mode 100644 index 0000000000..c4241d6155 --- /dev/null +++ b/faiss/cppcontrib/knowhere/impl/Bruteforce.h @@ -0,0 +1,91 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// C is CMax<> or CMin<> +template +void brute_force_search_impl( + const idx_t ntotal, + DistanceComputerT& __restrict qdis, + const FilterT& __restrict filter, + const idx_t k, + float* __restrict distances, + idx_t* __restrict labels) { + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + auto max_heap = std::make_unique[]>(k); + idx_t n_added = 0; + for (idx_t idx = 0; idx < ntotal; ++idx) { + if (filter.is_member(idx)) { + const float distance = qdis(idx); + if (n_added < k) { + n_added += 1; + heap_push(n_added, max_heap.get(), distance, idx); + } else if (C::cmp(max_heap[0].first, distance)) { + heap_replace_top(k, max_heap.get(), distance, idx); + } + } + } + + const idx_t len = std::min(n_added, idx_t(k)); + for (idx_t i = 0; i < len; i++) { + labels[len - i - 1] = max_heap[0].second; + distances[len - i - 1] = max_heap[0].first; + + heap_pop(len - i, max_heap.get()); + } + + // fill leftovers + if (len < k) { + for (idx_t idx = len; idx < k; idx++) { + labels[idx] = -1; + distances[idx] = C::neutral(); + } + } +} + +// C is CMax<> or CMin<> +template +void brute_force_range_search_impl( + const idx_t ntotal, + DistanceComputerT& __restrict qdis, + const FilterT& __restrict filter, + ResultHandlerT& __restrict rres) { + for (idx_t idx = 0; idx < ntotal; ++idx) { + if (filter.is_member(idx)) { + const float distance = qdis(idx); + rres.add_result(distance, idx); + } + } +} + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/impl/CountSizeIOWriter.h b/faiss/cppcontrib/knowhere/impl/CountSizeIOWriter.h new file mode 100644 index 0000000000..2d14ce4576 --- /dev/null +++ b/faiss/cppcontrib/knowhere/impl/CountSizeIOWriter.h @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +// an IOWriter that just counts the number of bytes without writing anything. +struct CountSizeIOWriter : IOWriter { + size_t total_size = 0; + + size_t operator()(const void*, size_t size, size_t nitems) override { + total_size += size * nitems; + return nitems; + } +}; + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/impl/HnswSearcher.h b/faiss/cppcontrib/knowhere/impl/HnswSearcher.h new file mode 100644 index 0000000000..3f3f550ea2 --- /dev/null +++ b/faiss/cppcontrib/knowhere/impl/HnswSearcher.h @@ -0,0 +1,578 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +// standard headers +#include +#include +#include +#include +#include +#include + +// Faiss-specific headers +#include +#include +#include +#include +#include +#include +#include +#include + +// Knowhere-specific headers +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +namespace { + +// whether to track statistics +constexpr bool track_hnsw_stats = true; + +} // namespace + +// Accomodates all the search logic and variables. +/// * DistanceComputerT is responsible for computing distances +/// * GraphVisitorT records visited edges +/// * VisitedT is responsible for tracking visited nodes +/// * FilterT is resposible for filtering unneeded nodes +/// Interfaces of all templates are tweaked to accept standard Faiss structures +/// with dynamic dispatching. Custom Knowhere structures are also accepted. +template < + typename DistanceComputerT, + typename GraphVisitorT, + typename VisitedT, + typename FilterT> +struct v2_hnsw_searcher { + using storage_idx_t = faiss::HNSW::storage_idx_t; + using idx_t = faiss::idx_t; + + // hnsw structure. + // the reference is not owned. + const faiss::HNSW& hnsw; + + // computes distances. it already knows the query vector. + // the reference is not owned. + DistanceComputerT& qdis; + + // records visited edges. + // the reference is not owned. + GraphVisitorT& graph_visitor; + + // tracks the nodes that have been visited already. + // the reference is not owned. + VisitedT& visited_nodes; + + // a filter for disabled nodes. + // the reference is not owned. + const FilterT& filter; + + // parameter for the filtering + const float kAlpha; + + // custom parameters of HNSW search. + // the pointer is not owned. + const faiss::SearchParametersHNSW* params; + + // + v2_hnsw_searcher( + const faiss::HNSW& hnsw_, + DistanceComputerT& qdis_, + GraphVisitorT& graph_visitor_, + VisitedT& visited_nodes_, + const FilterT& filter_, + const float kAlpha_, + const faiss::SearchParametersHNSW* params_) + : hnsw{hnsw_}, + qdis{qdis_}, + graph_visitor{graph_visitor_}, + visited_nodes{visited_nodes_}, + filter{filter_}, + kAlpha{kAlpha_}, + params{params_} {} + + v2_hnsw_searcher(const v2_hnsw_searcher&) = delete; + v2_hnsw_searcher(v2_hnsw_searcher&&) = delete; + v2_hnsw_searcher& operator=(const v2_hnsw_searcher&) = delete; + v2_hnsw_searcher& operator=(v2_hnsw_searcher&&) = delete; + + // greedily update a nearest vector at a given level. + // * the update starts from the value in 'nearest'. + faiss::HNSWStats greedy_update_nearest( + const int level, + storage_idx_t& nearest, + float& d_nearest) { + faiss::HNSWStats stats; + + for (;;) { + storage_idx_t prev_nearest = nearest; + + size_t begin = 0; + size_t end = 0; + hnsw.neighbor_range(nearest, level, &begin, &end); + + auto update_with_candidate = [&](const storage_idx_t idx, + const float dis) { + graph_visitor.visit_edge(level, prev_nearest, idx, dis); + + if (dis < d_nearest) { + nearest = idx; + d_nearest = dis; + } + }; + + size_t counter = 0; + storage_idx_t saved_indices[4]; + + // visit neighbors + size_t count = 0; + for (size_t i = begin; i < end; i++) { + storage_idx_t v = hnsw.neighbors[i]; + if (v < 0) { + // no more neighbors + break; + } + + count += 1; + + saved_indices[counter] = v; + counter += 1; + + if (counter == 4) { + // evaluate 4x distances at once + float dis[4] = {0, 0, 0, 0}; + qdis.distances_batch_4( + saved_indices[0], + saved_indices[1], + saved_indices[2], + saved_indices[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + update_with_candidate(saved_indices[id4], dis[id4]); + } + + counter = 0; + } + } + + // process leftovers + for (size_t id4 = 0; id4 < counter; id4++) { + // evaluate a single distance + const float dis = qdis(saved_indices[id4]); + + update_with_candidate(saved_indices[id4], dis); + } + + // update stats + if (track_hnsw_stats) { + stats.ndis += count; + stats.nhops += 1; + } + + // we're done if there we no changes + if (nearest == prev_nearest) { + return stats; + } + } + } + + // no loops, just check neighbors of a single node. + template + faiss::HNSWStats evaluate_single_node( + const idx_t node_id, + const int level, + float& accumulated_alpha, + FuncAddCandidate func_add_candidate) { + // // unused + // bool do_dis_check = params ? params->check_relative_distance + // : hnsw.check_relative_distance; + + faiss::HNSWStats stats; + + size_t begin = 0; + size_t end = 0; + hnsw.neighbor_range(node_id, level, &begin, &end); + + // todo: add prefetch + size_t counter = 0; + storage_idx_t saved_indices[4]; + int saved_statuses[4]; + + size_t ndis = 0; + for (size_t j = begin; j < end; j++) { + const storage_idx_t v1 = hnsw.neighbors[j]; + + if (v1 < 0) { + // no more neighbors + break; + } + + // already visited? + if (visited_nodes.get(v1)) { + // yes, visited. + graph_visitor.visit_edge(level, node_id, v1, -1); + continue; + } + + // not visited. mark as visited. + visited_nodes.set(v1); + + // is the node disabled? + int status = knowhere::Neighbor::kValid; + if (!filter.is_member(v1)) { + // yes, disabled + status = knowhere::Neighbor::kInvalid; + + // sometimes, disabled nodes are allowed to be used + accumulated_alpha += kAlpha; + if (accumulated_alpha < 1.0f) { + continue; + } + + accumulated_alpha -= 1.0f; + } + + saved_indices[counter] = v1; + saved_statuses[counter] = status; + counter += 1; + + ndis += 1; + + if (counter == 4) { + // evaluate 4x distances at once + float dis[4] = {0, 0, 0, 0}; + qdis.distances_batch_4( + saved_indices[0], + saved_indices[1], + saved_indices[2], + saved_indices[3], + dis[0], + dis[1], + dis[2], + dis[3]); + + for (size_t id4 = 0; id4 < 4; id4++) { + // record a traversed edge + graph_visitor.visit_edge( + level, node_id, saved_indices[id4], dis[id4]); + + // add a record of visited nodes + knowhere::Neighbor nn( + saved_indices[id4], dis[id4], saved_statuses[id4]); + if (func_add_candidate(nn)) { +#if defined(USE_PREFETCH) + // TODO + // _mm_prefetch(get_linklist0(v), _MM_HINT_T0); +#endif + } + } + + counter = 0; + } + } + + // process leftovers + for (size_t id4 = 0; id4 < counter; id4++) { + // evaluate a single distance + const float dis = qdis(saved_indices[id4]); + + // record a traversed edge + graph_visitor.visit_edge(level, node_id, saved_indices[id4], dis); + + // add a record of visited + knowhere::Neighbor nn(saved_indices[id4], dis, saved_statuses[id4]); + if (func_add_candidate(nn)) { +#if defined(USE_PREFETCH) + // TODO + // _mm_prefetch(get_linklist0(v), _MM_HINT_T0); +#endif + } + } + + // update stats + if (track_hnsw_stats) { + stats.ndis = ndis; + stats.nhops = 1; + } + + // done + return stats; + } + + // perform the search on a given level. + // it is assumed that retset is initialized and contains the initial nodes. + faiss::HNSWStats search_on_a_level( + knowhere::NeighborSetDoublePopList& retset, + const int level, + knowhere::IteratorMinHeap* const __restrict disqualified = nullptr, + const float initial_accumulated_alpha = 1.0f) { + faiss::HNSWStats stats; + + // + float accumulated_alpha = initial_accumulated_alpha; + + // what to do with a accepted candidate + auto add_search_candidate = [&](const knowhere::Neighbor n) { + return retset.insert(n, disqualified); + }; + + // iterate while possible + while (retset.has_next()) { + // get a node to be processed + const knowhere::Neighbor neighbor = retset.pop(); + + // analyze its neighbors + faiss::HNSWStats local_stats = evaluate_single_node( + neighbor.id, + level, + accumulated_alpha, + add_search_candidate); + + // update stats + if (track_hnsw_stats) { + stats.combine(local_stats); + } + } + + // done + return stats; + } + + // traverse down to the level 0 + faiss::HNSWStats greedy_search_top_levels( + storage_idx_t& nearest, + float& d_nearest) { + faiss::HNSWStats stats; + + // iterate through upper levels + for (int level = hnsw.max_level; level >= 1; level--) { + // update the visitor + graph_visitor.visit_level(level); + + // alter the value of 'nearest' + faiss::HNSWStats local_stats = + greedy_update_nearest(level, nearest, d_nearest); + + // update stats + if (track_hnsw_stats) { + stats.combine(local_stats); + } + } + + return stats; + } + + // perform the search. + faiss::HNSWStats search( + const idx_t k, + float* __restrict distances, + idx_t* __restrict labels) { + faiss::HNSWStats stats; + + // is the graph empty? + if (hnsw.entry_point == -1) { + return stats; + } + + // grab some needed parameters + const int efSearch = params ? params->efSearch : hnsw.efSearch; + + // yes. + // greedy search on upper levels. + + // initialize the starting point. + storage_idx_t nearest = hnsw.entry_point; + float d_nearest = qdis(nearest); + + // iterate through upper levels + auto bottom_levels_stats = greedy_search_top_levels(nearest, d_nearest); + + // update stats + if (track_hnsw_stats) { + stats.combine(bottom_levels_stats); + } + + // level 0 search + + // update the visitor + graph_visitor.visit_level(0); + + // initialize the container for candidates + const idx_t n_candidates = std::max((idx_t)efSearch, k); + knowhere::NeighborSetDoublePopList retset(n_candidates); + + // initialize retset with a single 'nearest' point + { + if (!filter.is_member(nearest)) { + retset.insert(knowhere::Neighbor( + nearest, d_nearest, knowhere::Neighbor::kInvalid)); + } else { + retset.insert(knowhere::Neighbor( + nearest, d_nearest, knowhere::Neighbor::kValid)); + } + + visited_nodes[nearest] = true; + } + + // perform the search of the level 0. + faiss::HNSWStats local_stats = search_on_a_level(retset, 0); + + // todo: switch to brute-force in case of (retset.size() < k) + + // populate the result + const idx_t len = std::min((idx_t)retset.size(), k); + for (idx_t i = 0; i < len; i++) { + distances[i] = retset[i].distance; + labels[i] = (idx_t)retset[i].id; + } + + // update stats + if (track_hnsw_stats) { + stats.combine(local_stats); + } + + // done + return stats; + } + + faiss::HNSWStats range_search( + const float radius, + typename faiss::RangeSearchBlockResultHandler< + faiss::CMax>:: + SingleResultHandler* const __restrict rres) { + faiss::HNSWStats stats; + + // is the graph empty? + if (hnsw.entry_point == -1) { + return stats; + } + + // grab some needed parameters + const int efSearch = params ? params->efSearch : hnsw.efSearch; + + // yes. + // greedy search on upper levels. + + // initialize the starting point. + storage_idx_t nearest = hnsw.entry_point; + float d_nearest = qdis(nearest); + + // iterate through upper levels + auto bottom_levels_stats = greedy_search_top_levels(nearest, d_nearest); + + // update stats + if (track_hnsw_stats) { + stats.combine(bottom_levels_stats); + } + + // level 0 search + + // update the visitor + graph_visitor.visit_level(0); + + // initialize the container for candidates + const idx_t n_candidates = efSearch; + knowhere::NeighborSetDoublePopList retset(n_candidates); + + // initialize retset with a single 'nearest' point + { + if (!filter.is_member(nearest)) { + retset.insert(knowhere::Neighbor( + nearest, d_nearest, knowhere::Neighbor::kInvalid)); + } else { + retset.insert(knowhere::Neighbor( + nearest, d_nearest, knowhere::Neighbor::kValid)); + } + + visited_nodes[nearest] = true; + } + + // perform the search of the level 0. + faiss::HNSWStats local_stats = search_on_a_level(retset, 0); + + // update stats + if (track_hnsw_stats) { + stats.combine(local_stats); + } + + // select candidates that match our criteria + faiss::HNSWStats pick_stats; + + visited_nodes.clear(); + + std::queue> radius_queue; + for (size_t i = retset.size(); (i--) > 0;) { + const auto candidate = retset[i]; + if (candidate.distance < radius) { + radius_queue.push({candidate.distance, candidate.id}); + rres->add_result(candidate.distance, candidate.id); + + visited_nodes[candidate.id] = true; + } + } + + while (!radius_queue.empty()) { + auto current = radius_queue.front(); + radius_queue.pop(); + + size_t id_begin = 0; + size_t id_end = 0; + hnsw.neighbor_range(current.second, 0, &id_begin, &id_end); + + for (size_t id = id_begin; id < id_end; id++) { + const auto ngb = hnsw.neighbors[id]; + if (ngb == -1) { + break; + } + + if (visited_nodes[ngb]) { + continue; + } + + visited_nodes[ngb] = true; + + if (filter.is_member(ngb)) { + const float dis = qdis(ngb); + if (dis < radius) { + radius_queue.push({dis, ngb}); + rres->add_result(dis, ngb); + } + + if (track_hnsw_stats) { + pick_stats.ndis += 1; + } + } + } + } + + // update stats + if (track_hnsw_stats) { + stats.combine(pick_stats); + } + + return stats; + } +}; + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/impl/Neighbor.h b/faiss/cppcontrib/knowhere/impl/Neighbor.h new file mode 100644 index 0000000000..305a8bedee --- /dev/null +++ b/faiss/cppcontrib/knowhere/impl/Neighbor.h @@ -0,0 +1,226 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +struct Neighbor { + static constexpr int kChecked = 0; + static constexpr int kValid = 1; + static constexpr int kInvalid = 2; + + unsigned id; + float distance; + int status; + + Neighbor() = default; + Neighbor(unsigned id, float distance, int status) + : id{id}, distance{distance}, status(status) {} + + inline bool operator<(const Neighbor& other) const { + return distance < other.distance; + } + + inline bool operator>(const Neighbor& other) const { + return distance > other.distance; + } +}; + +using IteratorMinHeap = std:: + priority_queue, std::greater>; + +template +class NeighborSetPopList { + private: + inline void insert_helper(const Neighbor& nbr, size_t pos) { + // move + std::memmove( + &data_[pos + 1], &data_[pos], (size_ - pos) * sizeof(Neighbor)); + if (size_ < capacity_) { + size_++; + } + + // insert + data_[pos] = nbr; + } + + public: + explicit NeighborSetPopList(size_t capacity) + : capacity_(capacity), data_(capacity + 1) {} + + inline bool insert( + const Neighbor nbr, + IteratorMinHeap* disqualified = nullptr) { + auto pos = + std::upper_bound(&data_[0], &data_[0] + size_, nbr) - &data_[0]; + if (pos >= capacity_) { + if (disqualified) { + disqualified->push(nbr); + } + return false; + } + if (size_ == capacity_ && disqualified) { + disqualified->push(data_[size_ - 1]); + } + insert_helper(nbr, pos); + if constexpr (need_save) { + if (pos < cur_) { + cur_ = pos; + } + } + return true; + } + + inline auto pop() -> Neighbor { + auto ret = data_[cur_]; + if constexpr (need_save) { + data_[cur_].status = Neighbor::kChecked; + cur_++; + while (cur_ < size_ && data_[cur_].status == Neighbor::kChecked) { + cur_++; + } + } else { + if (size_ > 1) { + std::memmove( + &data_[0], &data_[1], (size_ - 1) * sizeof(Neighbor)); + } + size_--; + } + return ret; + } + + inline auto has_next() const -> bool { + if constexpr (need_save) { + return cur_ < size_; + } else { + return size_ > 0; + } + } + + inline auto size() const -> size_t { + return size_; + } + + inline auto cur() const -> const Neighbor& { + if constexpr (need_save) { + return data_[cur_]; + } else { + return data_[0]; + } + } + + inline auto at_search_back_dist() const -> float { + if (size_ < capacity_) { + return std::numeric_limits::max(); + } + return data_[capacity_ - 1].distance; + } + + void clear() { + size_ = 0; + cur_ = 0; + } + + inline const Neighbor& operator[](size_t i) { + return data_[i]; + } + + private: + size_t capacity_ = 0, size_ = 0, cur_ = 0; + std::vector data_; +}; + +class NeighborSetDoublePopList { + public: + explicit NeighborSetDoublePopList(size_t capacity = 0) { + valid_ns_ = std::make_unique>(capacity); + invalid_ns_ = std::make_unique>(capacity); + } + + // will push any neighbor that does not fit into NeighborSet to + // disqualified. When searching for iterator, those points removed from + // NeighborSet may be qualified candidates as the iterator iterates, thus we + // need to retain instead of disposing them. + bool insert(const Neighbor& nbr, IteratorMinHeap* disqualified = nullptr) { + if (nbr.status == Neighbor::kValid) { + return valid_ns_->insert(nbr, disqualified); + } else { + if (nbr.distance < valid_ns_->at_search_back_dist()) { + return invalid_ns_->insert(nbr, disqualified); + } else if (disqualified) { + disqualified->push(nbr); + } + } + return false; + } + auto pop() -> Neighbor { + return pop_based_on_distance(); + } + + auto has_next() const -> bool { + return valid_ns_->has_next() || + (invalid_ns_->has_next() && + invalid_ns_->cur().distance < + valid_ns_->at_search_back_dist()); + } + + inline const Neighbor& operator[](size_t i) { + return (*valid_ns_)[i]; + } + + inline size_t size() const { + return valid_ns_->size(); + } + + private: + auto pop_based_on_distance() -> Neighbor { + bool hasCandNext = invalid_ns_->has_next(); + bool hasResNext = valid_ns_->has_next(); + + if (hasCandNext && hasResNext) { + return invalid_ns_->cur().distance < valid_ns_->cur().distance + ? invalid_ns_->pop() + : valid_ns_->pop(); + } + if (hasCandNext != hasResNext) { + return hasCandNext ? invalid_ns_->pop() : valid_ns_->pop(); + } + return {0, 0, Neighbor::kValid}; + } + + std::unique_ptr> valid_ns_ = nullptr; + std::unique_ptr> invalid_ns_ = nullptr; +}; + +static inline int InsertIntoPool(Neighbor* addr, intptr_t size, Neighbor nn) { + intptr_t p = std::lower_bound(addr, addr + size, nn) - addr; + std::memmove(addr + p + 1, addr + p, (size - p) * sizeof(Neighbor)); + addr[p] = nn; + return p; +} + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/cppcontrib/knowhere/utils/Bitset.h b/faiss/cppcontrib/knowhere/utils/Bitset.h new file mode 100644 index 0000000000..f39c05441d --- /dev/null +++ b/faiss/cppcontrib/knowhere/utils/Bitset.h @@ -0,0 +1,124 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#pragma once + +#include +#include +#include +#include + +namespace faiss { +namespace cppcontrib { +namespace knowhere { + +struct Bitset final { + struct Proxy { + uint8_t& element; + uint8_t mask; + + inline Proxy(uint8_t& _element, const size_t _shift) + : element{_element}, mask(uint8_t(1) << _shift) {} + + inline operator bool() const { + return ((element & mask) != 0); + } + + inline Proxy& operator=(const bool value) { + if (value) { + set(); + } else { + reset(); + } + return *this; + } + + inline void set() { + element |= mask; + } + + inline void reset() { + element &= ~mask; + } + }; + + inline Bitset() {} + + // create an uncleared bitset + inline static Bitset create_uninitialized(const size_t initial_size) { + Bitset bitset; + + const size_t nbytes = (initial_size + 7) / 8; + + bitset.bits = std::make_unique(nbytes); + bitset.size = initial_size; + + return bitset; + } + + // create an initialized bitset + inline static Bitset create_cleared(const size_t initial_size) { + Bitset bitset = create_uninitialized(initial_size); + bitset.clear(); + + return bitset; + } + + Bitset(const Bitset&) = delete; + Bitset(Bitset&&) = default; + Bitset& operator=(const Bitset&) = delete; + Bitset& operator=(Bitset&&) = default; + + inline bool get(const size_t index) const { + return (bits[index >> 3] & (0x1 << (index & 0x7))); + } + + inline void set(const size_t index) { + bits[index >> 3] |= uint8_t(0x1 << (index & 0x7)); + } + + inline void reset(const size_t index) { + bits[index >> 3] &= (~uint8_t(0x1 << (index & 0x7))); + } + + inline const uint8_t* get_ptr(const size_t index) const { + return bits.get() + index / 8; + } + + inline uint8_t* get_ptr(const size_t index) { + return bits.get() + index / 8; + } + + inline void clear() { + const size_t nbytes = (size + 7) / 8; + std::memset(bits.get(), 0, nbytes); + } + + inline Proxy operator[](const size_t bit_idx) { + uint8_t& element = bits[bit_idx / 8]; + const size_t shift = bit_idx & 7; + return Proxy{element, shift}; + } + + inline bool operator[](const size_t bit_idx) const { + return get(bit_idx); + } + + std::unique_ptr bits; + size_t size = 0; +}; + +} // namespace knowhere +} // namespace cppcontrib +} // namespace faiss diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp index 092df879bf..e7972ad227 100644 --- a/faiss/index_factory.cpp +++ b/faiss/index_factory.cpp @@ -767,7 +767,7 @@ std::unique_ptr index_factory_sub( } if (verbose) { - printf("after () normalization: %s %ld parenthesis indexes d=%d\n", + printf("after () normalization: %s %zd parenthesis indexes d=%d\n", description.c_str(), parenthesis_indexes.size(), d);