Skip to content

Commit

Permalink
wrappers and a customized HNSW search
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva committed Oct 8, 2024
1 parent 2e6551f commit 59e2eca
Show file tree
Hide file tree
Showing 20 changed files with 2,223 additions and 97 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,5 @@ if(BUILD_TESTING)
endif()
endif()
endif()

# Add the faiss/cppcontrib/knowhere directory (index wrappers and their
# benchmark) to the build. This sits outside any BUILD_TESTING guard.
add_subdirectory(faiss/cppcontrib/knowhere)
1 change: 0 additions & 1 deletion benchs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@

# bench_ivf_selector benchmark executable. EXCLUDE_FROM_ALL keeps it out of
# the default build, so it is only compiled when requested explicitly.
add_executable(bench_ivf_selector EXCLUDE_FROM_ALL bench_ivf_selector.cpp)
target_link_libraries(bench_ivf_selector PRIVATE faiss)

55 changes: 55 additions & 0 deletions cmake/link_to_faiss_lib.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# link_to_faiss_lib(<target>)
#
# Links <target> (PRIVATE) against the faiss library variant selected by
# FAISS_OPT_LEVEL and adds the matching SIMD compile options for C++ sources:
#
#   "avx2"    -> faiss_avx2    (-mavx2 -mfma, or /arch:AVX2 on Windows)
#   "avx512"  -> faiss_avx512  (AVX2 + AVX-512 F/CD/VL/DQ/BW, or /arch:AVX512)
#   "sve"     -> faiss_sve     (+sve appended to the effective -march)
#   otherwise -> faiss         (no extra compile options)
#
# Example:
#   add_executable(my_bench my_bench.cpp)
#   link_to_faiss_lib(my_bench)
function(link_to_faiss_lib target)
  if("${FAISS_OPT_LEVEL}" STREQUAL "avx2")
    if(NOT WIN32)
      target_compile_options(${target} PRIVATE
        "$<$<COMPILE_LANGUAGE:CXX>:-mavx2;-mfma>")
    else()
      target_compile_options(${target} PRIVATE
        "$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>")
    endif()
    target_link_libraries(${target} PRIVATE faiss_avx2)

  elseif("${FAISS_OPT_LEVEL}" STREQUAL "avx512")
    if(NOT WIN32)
      # Each flag is listed exactly once; the list is quoted so the
      # generator expression reaches CMake as a single argument.
      target_compile_options(${target} PRIVATE
        "$<$<COMPILE_LANGUAGE:CXX>:-mavx2;-mfma;-mavx512f;-mavx512cd;-mavx512vl;-mavx512dq;-mavx512bw>")
    else()
      target_compile_options(${target} PRIVATE
        "$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>")
    endif()
    target_link_libraries(${target} PRIVATE faiss_avx512)

  elseif("${FAISS_OPT_LEVEL}" STREQUAL "sve")
    if(NOT WIN32)
      # The effective -march differs per configuration because it may come
      # from CMAKE_CXX_FLAGS or from the per-config flags.
      # NOTE: only the Debug and Release configurations are handled here.
      foreach(config IN ITEMS DEBUG RELEASE)
        # Trailing space so that a flag at the very end is still delimited.
        set(cxx_flags "${CMAKE_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${config}} ")

        if("${cxx_flags}" MATCHES "(^| )-march=native")
          # Do nothing, expect SVE to be enabled by -march=native
        elseif("${cxx_flags}" MATCHES "(^| )(-march=armv[0-9]+(\\.[1-9]+)?-[^+ ](\\+[^+$ ]+)*)")
          # An explicit ARM -march is present: repeat it with +sve appended.
          target_compile_options(${target} PRIVATE
            "$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<CONFIG:${config}>>:${CMAKE_MATCH_2}+sve>")
        elseif(NOT "${cxx_flags}" MATCHES "(^| )-march=armv")
          # No valid -march, so specify -march=armv8-a+sve as the default
          target_compile_options(${target} PRIVATE
            "$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<CONFIG:${config}>>:-march=armv8-a+sve>")
        endif()
      endforeach()
    else()
      # TODO: support Windows
    endif()
    target_link_libraries(${target} PRIVATE faiss_sve)

  else()
    # Any other (or unset) optimization level uses the generic library.
    target_link_libraries(${target} PRIVATE faiss)
  endif()
endfunction()
29 changes: 29 additions & 0 deletions faiss/cppcontrib/knowhere/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
cmake_minimum_required(VERSION 3.24.0)

project(faiss_cppcontrib_knowhere)

# IndexBruteForceWrapper.cpp contains `#pragma omp` regions, and the benchmark
# links the OpenMP runtime as well.
find_package(OpenMP REQUIRED)

set(SRC_FAISS_CPPCONTRIB_KNOWHERE
  IndexBruteForceWrapper.cpp
  IndexHNSWWrapper.cpp
  IndexWrapper.cpp
)

# Provides link_to_faiss_lib(), which selects the faiss variant and SIMD flags.
include("${PROJECT_SOURCE_DIR}/../../../cmake/link_to_faiss_lib.cmake")

# --- Wrapper library ---------------------------------------------------------
add_library(faiss_cppcontrib_knowhere ${SRC_FAISS_CPPCONTRIB_KNOWHERE})
link_to_faiss_lib(faiss_cppcontrib_knowhere)

# Compile the wrapper sources with OpenMP enabled; without this the
# `#pragma omp` directives in the library are ignored by the compiler unless
# the flag happens to arrive from elsewhere.
target_link_libraries(faiss_cppcontrib_knowhere PRIVATE OpenMP::OpenMP_CXX)

# Sources include headers as <faiss/...>, so the repository root is the
# include directory.
target_include_directories(faiss_cppcontrib_knowhere PRIVATE
  "${PROJECT_SOURCE_DIR}/../../.."
)

# --- Benchmark ---------------------------------------------------------------
add_executable(bench_hnsw_knowhere benchs/bench_hnsw_knowhere.cpp)

link_to_faiss_lib(bench_hnsw_knowhere)
target_link_libraries(bench_hnsw_knowhere PRIVATE
  faiss_cppcontrib_knowhere
  OpenMP::OpenMP_CXX
)

target_include_directories(bench_hnsw_knowhere PRIVATE
  "${PROJECT_SOURCE_DIR}/../../.."
)
223 changes: 223 additions & 0 deletions faiss/cppcontrib/knowhere/IndexBruteForceWrapper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#include <faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h>

#include <algorithm>
#include <memory>

#include <faiss/Index.h>
#include <faiss/MetricType.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultHandler.h>

#include <faiss/cppcontrib/knowhere/impl/Bruteforce.h>

namespace faiss {
namespace cppcontrib {
namespace knowhere {

/// Constructs the wrapper by forwarding `underlying_index` to the
/// IndexWrapper base; the constructor body itself does nothing else.
IndexBruteForceWrapper::IndexBruteForceWrapper(Index* underlying_index)
: IndexWrapper{underlying_index} {}

/// k-nearest-neighbor search that scans all `index->ntotal` entries of the
/// wrapped index for every query.
///
/// @param n          number of query vectors
/// @param x          query vectors; query i is read from x + i * index->d
/// @param k          number of results per query, must be > 0
/// @param distances  output distances; query i writes to distances + i * k
/// @param labels     output ids; query i writes to labels + i * k
/// @param params     optional search parameters; only `sel` (the id filter)
///                   is read here, and nullptr means "no filtering"
///
/// Throws (via FAISS_THROW_IF_NOT) when k <= 0.
void IndexBruteForceWrapper::search(
        idx_t n,
        const float* __restrict x,
        idx_t k,
        float* __restrict distances,
        idx_t* __restrict labels,
        const SearchParameters* __restrict params) const {
    FAISS_THROW_IF_NOT(k > 0);

    // The filter does not depend on the query, so resolve it once.
    IDSelector* const sel = (params == nullptr) ? nullptr : params->sel;

    // Queries are processed in batches so that the interrupt callback can be
    // polled between batches.
    idx_t check_period =
            InterruptCallback::get_period_hint(index->d * index->ntotal);

    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
        idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel if (i1 - i0 > 1)
        {
            // One distance computer per thread.
            std::unique_ptr<DistanceComputer> dis(
                    index->get_distance_computer());

#pragma omp for schedule(guided)
            for (idx_t i = i0; i < i1; i++) {
                // prepare the query
                dis->set_query(x + i * index->d);

                // Output slice of this query. Results are laid out with k
                // entries per query, so the row offset is i * k (using the
                // vector dimension here would write outside the row).
                idx_t* const __restrict local_ids = labels + i * k;
                float* const __restrict local_distances = distances + i * k;

                if (is_similarity_metric(index->metric_type)) {
                    using C = CMin<float, idx_t>;

                    if (sel == nullptr) {
                        // Compiler is expected to de-virtualize virtual method
                        // calls
                        IDSelectorAll sel_all;
                        brute_force_search_impl<
                                C,
                                DistanceComputer,
                                IDSelectorAll>(
                                index->ntotal,
                                *dis,
                                sel_all,
                                k,
                                local_distances,
                                local_ids);
                    } else {
                        brute_force_search_impl<
                                C,
                                DistanceComputer,
                                IDSelector>(
                                index->ntotal,
                                *dis,
                                *sel,
                                k,
                                local_distances,
                                local_ids);
                    }
                } else {
                    using C = CMax<float, idx_t>;

                    if (sel == nullptr) {
                        // Compiler is expected to de-virtualize virtual method
                        // calls
                        IDSelectorAll sel_all;
                        brute_force_search_impl<
                                C,
                                DistanceComputer,
                                IDSelectorAll>(
                                index->ntotal,
                                *dis,
                                sel_all,
                                k,
                                local_distances,
                                local_ids);
                    } else {
                        brute_force_search_impl<
                                C,
                                DistanceComputer,
                                IDSelector>(
                                index->ntotal,
                                *dis,
                                *sel,
                                k,
                                local_distances,
                                local_ids);
                    }
                }
            }
        }

        InterruptCallback::check();
    }
}

/// Range search that scans all `index->ntotal` entries of the wrapped index
/// for every query and reports matches through `result`.
///
/// @param n       number of query vectors
/// @param x       query vectors; query i is read from x + i * index->d
/// @param radius  search radius handed to the range result handlers
/// @param result  receives the results
/// @param params  optional search parameters; only `sel` (the id filter) is
///                read here, and nullptr means "no filtering"
void IndexBruteForceWrapper::range_search(
        idx_t n,
        const float* __restrict x,
        float radius,
        RangeSearchResult* __restrict result,
        const SearchParameters* __restrict params) const {
    using BlockHandlerCMax =
            RangeSearchBlockResultHandler<CMax<float, int64_t>>;
    using BlockHandlerCMin =
            RangeSearchBlockResultHandler<CMin<float, int64_t>>;
    using QueryHandlerCMax = typename BlockHandlerCMax::SingleResultHandler;
    using QueryHandlerCMin = typename BlockHandlerCMin::SingleResultHandler;

    // Both block handlers write into the same `result`; which one is used
    // depends on the metric type of the wrapped index.
    BlockHandlerCMax block_cmax(result, radius);
    BlockHandlerCMin block_cmin(result, radius);

    // Batch size between two polls of the interrupt callback.
    const idx_t batch_size =
            InterruptCallback::get_period_hint(index->d * index->ntotal);

    for (idx_t batch_begin = 0; batch_begin < n; batch_begin += batch_size) {
        const idx_t batch_end = std::min(batch_begin + batch_size, n);

#pragma omp parallel if (batch_end - batch_begin > 1)
        {
            // Per-thread state: a distance computer and one single-query
            // handler for each comparator.
            std::unique_ptr<DistanceComputer> computer(
                    index->get_distance_computer());

            QueryHandlerCMax query_cmax(block_cmax);
            QueryHandlerCMin query_cmin(block_cmin);

#pragma omp for schedule(guided)
            for (idx_t q = batch_begin; q < batch_end; q++) {
                // Point the distance computer at the current query vector.
                computer->set_query(x + q * index->d);

                // Optional id filter supplied by the caller.
                IDSelector* __restrict sel = nullptr;
                if (params != nullptr) {
                    sel = params->sel;
                }

                if (!is_similarity_metric(index->metric_type)) {
                    query_cmax.begin(q);

                    if (sel != nullptr) {
                        brute_force_range_search_impl<
                                QueryHandlerCMax,
                                DistanceComputer,
                                IDSelector>(
                                index->ntotal, *computer, *sel, query_cmax);
                    } else {
                        // A concrete selector type is used here so that the
                        // compiler can de-virtualize the selector calls.
                        IDSelectorAll accept_all;

                        brute_force_range_search_impl<
                                QueryHandlerCMax,
                                DistanceComputer,
                                IDSelectorAll>(
                                index->ntotal,
                                *computer,
                                accept_all,
                                query_cmax);
                    }

                    query_cmax.end();
                } else {
                    query_cmin.begin(q);

                    if (sel != nullptr) {
                        brute_force_range_search_impl<
                                QueryHandlerCMin,
                                DistanceComputer,
                                IDSelector>(
                                index->ntotal, *computer, *sel, query_cmin);
                    } else {
                        // A concrete selector type is used here so that the
                        // compiler can de-virtualize the selector calls.
                        IDSelectorAll accept_all;

                        brute_force_range_search_impl<
                                QueryHandlerCMin,
                                DistanceComputer,
                                IDSelectorAll>(
                                index->ntotal,
                                *computer,
                                accept_all,
                                query_cmin);
                    }

                    query_cmin.end();
                }
            }
        }

        InterruptCallback::check();
    }
}

} // namespace knowhere
} // namespace cppcontrib
} // namespace faiss
49 changes: 49 additions & 0 deletions faiss/cppcontrib/knowhere/IndexBruteForceWrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#pragma once

#include <faiss/Index.h>

#include <faiss/cppcontrib/knowhere/IndexWrapper.h>

namespace faiss {
namespace cppcontrib {
namespace knowhere {

/// Index wrapper that overrides the search entry points of IndexWrapper with
/// a brute-force scan of the wrapped index.
struct IndexBruteForceWrapper : IndexWrapper {
    /// @param underlying_index  index to wrap; forwarded to IndexWrapper
    IndexBruteForceWrapper(Index* underlying_index);

    /// Entry point for k-nearest-neighbor search.
    ///
    /// @param n          number of query vectors
    /// @param x          query vectors
    /// @param k          number of results per query
    /// @param distances  output distances, k entries per query
    /// @param labels     output ids, k entries per query
    /// @param params     optional search parameters (may be nullptr); the
    ///                   default keeps calls that omit it compiling when made
    ///                   through this type
    void search(
            idx_t n,
            const float* x,
            idx_t k,
            float* distances,
            idx_t* labels,
            const SearchParameters* params = nullptr) const override;

    /// Entry point for range search.
    ///
    /// @param n       number of query vectors
    /// @param x       query vectors
    /// @param radius  search radius
    /// @param result  receives the results
    /// @param params  optional search parameters (may be nullptr)
    void range_search(
            idx_t n,
            const float* x,
            float radius,
            RangeSearchResult* result,
            const SearchParameters* params = nullptr) const override;
};

} // namespace knowhere
} // namespace cppcontrib
} // namespace faiss
Loading

0 comments on commit 59e2eca

Please sign in to comment.