Skip to content

Commit

Permalink
Support KNN searches for FAISS IVF indices (#13207)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #13207

Differential Revision: D67184936
  • Loading branch information
ltamasi authored and facebook-github-bot committed Dec 13, 2024
1 parent 2ff9f95 commit 3bddcbf
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 12 deletions.
17 changes: 17 additions & 0 deletions include/rocksdb/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,23 @@ struct ReadOptions {
// Default: false
bool allow_unprepared_value = false;

// The maximum number of neighbors K to return when performing a
// K-nearest-neighbors vector similarity search. The number of neighbors
// returned can be smaller if there are not enough vectors in the inverted
// lists probed. Only applicable to FAISS IVF secondary indices. See also
// `SecondaryIndex::NewIterator` and `similarity_search_probes` below.
//
// Default: 10
size_t similarity_search_neighbors = 10;

// The number of inverted lists to probe when performing a K-nearest-neighbors
// vector similarity search. Only applicable to FAISS IVF secondary indices.
// See also `SecondaryIndex::NewIterator` and `similarity_search_neighbors`
// above.
//
// Default: 1
size_t similarity_search_probes = 1;

// *** END options only relevant to iterators or scans ***

// *** BEGIN options for RocksDB internal use only ***
Expand Down
23 changes: 23 additions & 0 deletions include/rocksdb/utilities/secondary_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

#pragma once

#include <memory>
#include <optional>
#include <string>
#include <variant>

#include "rocksdb/iterator.h"
#include "rocksdb/options.h"
#include "rocksdb/rocksdb_namespace.h"
#include "rocksdb/slice.h"
#include "rocksdb/status.h"
Expand Down Expand Up @@ -97,6 +100,26 @@ class SecondaryIndex {
const Slice& previous_column_value,
std::optional<std::variant<Slice, std::string>>* secondary_value)
const = 0;

// Create an iterator that can be used to query the secondary index by calling
// the Seek API with a search target. The exact semantics of the returned
// iterator depend on the index and are implementation-specific. In the most
// common case, the search target is a primary column value, and the iterator
// returns all primary keys that have the given column value; however, other
// semantics are also possible. For example, in the case of vector similarity
// search, the search target is a vector, and the iterator returns similar
// vectors from the index.
//
// The returned iterator is expected to expose primary keys (i.e. the
// secondary key prefix is expected to be removed). The iterator is expected
// to support the Seek API (see above) but not the SeekToFirst, SeekToLast,
// and SeekForPrev APIs.
//
// The input parameter underlying_it is used to read secondary index entries,
// and is thus expected to be an iterator over the index's secondary column
// family.
virtual std::unique_ptr<Iterator> NewIterator(
const ReadOptions& read_options, Iterator* underlying_it) const = 0;
};

} // namespace ROCKSDB_NAMESPACE
235 changes: 228 additions & 7 deletions utilities/secondary_index/faiss_ivf_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,155 @@
#include "utilities/secondary_index/faiss_ivf_index.h"

#include <cassert>
#include <optional>
#include <utility>

#include "faiss/invlists/InvertedLists.h"
#include "util/autovector.h"
#include "util/coding.h"

namespace ROCKSDB_NAMESPACE {

class FaissIVFIndex::KNNIterator : public Iterator {
public:
KNNIterator(faiss::IndexIVF* index, Iterator* underlying_it, size_t k,
size_t probes)
: index_(index),
underlying_it_(underlying_it),
k_(k),
probes_(probes),
distances_(k, 0.0f),
labels_(k, -1),
pos_(0) {
assert(index_);
assert(underlying_it_);
}

Iterator* GetUnderlyingIterator() const { return underlying_it_; }

faiss::idx_t AddKey(std::string&& key) {
keys_.emplace_back(std::move(key));

return static_cast<faiss::idx_t>(keys_.size()) - 1;
}

bool Valid() const override {
assert(labels_.size() == k_);
assert(distances_.size() == k_);

return status_.ok() && pos_ >= 0 && pos_ < k_ && labels_[pos_] >= 0;
}

void SeekToFirst() override {
status_ =
Status::NotSupported("SeekToFirst not supported for FaissIVFIndex");
}

void SeekToLast() override {
status_ =
Status::NotSupported("SeekToLast not supported for FaissIVFIndex");
}

void Seek(const Slice& target) override {
distances_.assign(k_, 0.0f);
labels_.assign(k_, -1);
status_ = Status::OK();
pos_ = 0;
keys_.clear();

faiss::SearchParametersIVF params;
params.nprobe = probes_;
params.inverted_list_context = this;

constexpr faiss::idx_t n = 1;

try {
index_->search(n, reinterpret_cast<const float*>(target.data()), k_,
distances_.data(), labels_.data(), &params);
} catch (const std::exception& e) {
status_ = Status::InvalidArgument(e.what());
}
}

void SeekForPrev(const Slice& /* target */) override {
status_ =
Status::NotSupported("SeekForPrev not supported for FaissIVFIndex");
}

void Next() override {
assert(Valid());

++pos_;
}

void Prev() override {
assert(Valid());

--pos_;
}

Status status() const override { return status_; }

Slice key() const override {
assert(Valid());
assert(labels_[pos_] >= 0);
assert(labels_[pos_] < keys_.size());

return keys_[labels_[pos_]];
}

Slice value() const override {
assert(Valid());

return Slice();
}

const WideColumns& columns() const override {
assert(Valid());

return kNoWideColumns;
}

Slice timestamp() const override {
assert(Valid());

return Slice();
}

Status GetProperty(std::string prop_name, std::string* prop) override {
if (!prop) {
return Status::InvalidArgument("No property pointer provided");
}

if (!Valid()) {
return Status::InvalidArgument("Iterator is not valid");
}

if (prop_name == kPropertyName_) {
*prop = std::to_string(distances_[pos_]);
return Status::OK();
}

return Iterator::GetProperty(std::move(prop_name), prop);
}

private:
faiss::IndexIVF* index_;
Iterator* underlying_it_;
size_t k_;
size_t probes_;
std::vector<float> distances_;
std::vector<faiss::idx_t> labels_;
Status status_;
faiss::idx_t pos_;
autovector<std::string> keys_;

static const std::string kPropertyName_;
};

const std::string FaissIVFIndex::KNNIterator::kPropertyName_ =
"rocksdb.faiss.ivf.index.distance";

class FaissIVFIndex::Adapter : public faiss::InvertedLists {
public:
Adapter(size_t num_lists, size_t code_size)
Expand All @@ -36,14 +179,13 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists {
return nullptr;
}

// Iterator-based read interface; not yet implemented
// Iterator-based read interface
faiss::InvertedListsIterator* get_iterator(
size_t /* list_no */,
void* /* inverted_list_context */ = nullptr) const override {
// TODO: implement this
size_t list_no, void* inverted_list_context = nullptr) const override {
KNNIterator* const it = static_cast<KNNIterator*>(inverted_list_context);
assert(it);

assert(false);
return nullptr;
return new IteratorAdapter(it, list_no, code_size);
}

// Write interface; only add_entry is implemented/required for now
Expand Down Expand Up @@ -80,6 +222,77 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists {
void resize(size_t /* list_no */, size_t /* new_size */) override {
assert(false);
}

private:
class IteratorAdapter : public faiss::InvertedListsIterator {
public:
IteratorAdapter(KNNIterator* it, size_t list_no, size_t code_size)
: it_(it),
underlying_it_(it->GetUnderlyingIterator()),
prefix_(FaissIVFIndex::SerializeLabel(list_no)),
prefix_slice_(prefix_),
code_size_(code_size) {
assert(it_);
assert(underlying_it_);

// FIXME: here we rely on the empty Slice being less than any other one,
// which is true for e.g. BytewiseComparator but not in general
underlying_it_->Seek(prefix_slice_);
Update();
}

bool is_available() const override { return id_and_codes_.has_value(); }

void next() override {
underlying_it_->Next();
Update();
}

std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes() override {
assert(is_available());

return *id_and_codes_;
}

private:
void Update() {
id_and_codes_.reset();

if (!underlying_it_->Valid()) {
return;
}

Slice key = underlying_it_->key();
if (!key.starts_with(prefix_slice_)) {
return;
}

if (!underlying_it_->PrepareValue()) {
throw std::runtime_error(
"Failed to prepare value during iteration in FaissIVFIndex");
}

const Slice& value = underlying_it_->value();
if (value.size() != code_size_) {
throw std::runtime_error(
"Code with unexpected size encountered during iteration in "
"FaissIVFIndex");
}

key.remove_prefix(prefix_slice_.size());

const faiss::idx_t id = it_->AddKey(key.ToString());

id_and_codes_.emplace(id, reinterpret_cast<const uint8_t*>(value.data()));
}

KNNIterator* it_;
Iterator* underlying_it_;
std::string prefix_;
Slice prefix_slice_;
size_t code_size_;
std::optional<std::pair<faiss::idx_t, const uint8_t*>> id_and_codes_;
};
};

std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) {
Expand All @@ -105,6 +318,7 @@ FaissIVFIndex::FaissIVFIndex(std::unique_ptr<faiss::IndexIVF>&& index,
assert(index_);
assert(index_->quantizer);

index_->parallel_mode = 0;
index_->replace_invlists(adapter_.get());
}

Expand Down Expand Up @@ -203,12 +417,19 @@ Status FaissIVFIndex::GetSecondaryValue(

if (code_str.size() != index_->code_size) {
return Status::InvalidArgument(
"Unexpected code returned by fine quantizer");
"Code with unexpected size returned by fine quantizer");
}

secondary_value->emplace(std::move(code_str));

return Status::OK();
}

std::unique_ptr<Iterator> FaissIVFIndex::NewIterator(
const ReadOptions& read_options, Iterator* it) const {
return std::make_unique<KNNIterator>(index_.get(), it,
read_options.similarity_search_neighbors,
read_options.similarity_search_probes);
}

} // namespace ROCKSDB_NAMESPACE
4 changes: 4 additions & 0 deletions utilities/secondary_index/faiss_ivf_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class FaissIVFIndex : public SecondaryIndex {
std::optional<std::variant<Slice, std::string>>*
secondary_value) const override;

std::unique_ptr<Iterator> NewIterator(const ReadOptions& read_options,
Iterator* underlying_it) const override;

private:
class KNNIterator;
class Adapter;

static std::string SerializeLabel(faiss::idx_t label);
Expand Down
Loading

0 comments on commit 3bddcbf

Please sign in to comment.