diff --git a/include/rocksdb/options.h b/include/rocksdb/options.h index b27f53b4a849..3d36fcd4715b 100644 --- a/include/rocksdb/options.h +++ b/include/rocksdb/options.h @@ -1940,6 +1940,23 @@ struct ReadOptions { // Default: false bool allow_unprepared_value = false; + // The maximum number of neighbors K to return when performing a + // K-nearest-neighbors vector similarity search. The number of neighbors + // returned can be smaller if there are not enough vectors in the inverted + // lists probed. Only applicable to FAISS IVF secondary indices. See also + // `SecondaryIndex::NewIterator` and `similarity_search_probes` below. + // + // Default: 10 + size_t similarity_search_neighbors = 10; + + // The number of inverted lists to probe when performing a K-nearest-neighbors + // vector similarity search. Only applicable to FAISS IVF secondary indices. + // See also `SecondaryIndex::NewIterator` and `similarity_search_neighbors` + // above. + // + // Default: 1 + size_t similarity_search_probes = 1; + // *** END options only relevant to iterators or scans *** // *** BEGIN options for RocksDB internal use only *** diff --git a/include/rocksdb/utilities/secondary_index.h b/include/rocksdb/utilities/secondary_index.h index 4b66da6226ff..7921a70afaf3 100644 --- a/include/rocksdb/utilities/secondary_index.h +++ b/include/rocksdb/utilities/secondary_index.h @@ -6,10 +6,13 @@ #pragma once +#include #include #include #include +#include "rocksdb/iterator.h" +#include "rocksdb/options.h" #include "rocksdb/rocksdb_namespace.h" #include "rocksdb/slice.h" #include "rocksdb/status.h" @@ -97,6 +100,26 @@ class SecondaryIndex { const Slice& previous_column_value, std::optional>* secondary_value) const = 0; + + // Create an iterator that can be used to query the secondary index by calling + // the Seek API with a search target. The exact semantics of the returned + // iterator depend on the index and are implementation-specific. In the most + // common case, the search target is a primary column value, and the iterator + // returns all primary keys that have the given column value; however, other + // semantics are also possible. For example, in the case of vector similarity + // search, the search target is a vector, and the iterator returns similar + // vectors from the index. + // + // The returned iterator is expected to expose primary keys (i.e. the + // secondary key prefix is expected to be removed). The iterator is expected + // to support the Seek API (see above) but not the SeekToFirst, SeekToLast, + // and SeekForPrev APIs. + // + // The input parameter underlying_it is used to read secondary index entries, + // and is thus expected to be an iterator over the index's secondary column + // family. + virtual std::unique_ptr NewIterator( + const ReadOptions& read_options, Iterator* underlying_it) const = 0; }; } // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index.cc b/utilities/secondary_index/faiss_ivf_index.cc index c419b98a2d1c..739995f0bc30 100644 --- a/utilities/secondary_index/faiss_ivf_index.cc +++ b/utilities/secondary_index/faiss_ivf_index.cc @@ -6,12 +6,155 @@ #include "utilities/secondary_index/faiss_ivf_index.h" #include +#include +#include #include "faiss/invlists/InvertedLists.h" +#include "util/autovector.h" #include "util/coding.h" namespace ROCKSDB_NAMESPACE { +class FaissIVFIndex::KNNIterator : public Iterator { + public: + KNNIterator(faiss::IndexIVF* index, Iterator* underlying_it, size_t k, + size_t probes) + : index_(index), + underlying_it_(underlying_it), + k_(k), + probes_(probes), + distances_(k, 0.0f), + labels_(k, -1), + pos_(0) { + assert(index_); + assert(underlying_it_); + } + + Iterator* GetUnderlyingIterator() const { return underlying_it_; } + + faiss::idx_t AddKey(std::string&& key) { + keys_.emplace_back(std::move(key)); + + return static_cast(keys_.size()) - 1; + } + + bool Valid() const override { + assert(labels_.size() == k_); + assert(distances_.size() == k_); + + return status_.ok() && pos_ >= 0 && pos_ < k_ && labels_[pos_] >= 0; + } + + void SeekToFirst() override { + status_ = + Status::NotSupported("SeekToFirst not supported for FaissIVFIndex"); + } + + void SeekToLast() override { + status_ = + Status::NotSupported("SeekToLast not supported for FaissIVFIndex"); + } + + void Seek(const Slice& target) override { + distances_.assign(k_, 0.0f); + labels_.assign(k_, -1); + status_ = Status::OK(); + pos_ = 0; + keys_.clear(); + + faiss::SearchParametersIVF params; + params.nprobe = probes_; + params.inverted_list_context = this; + + constexpr faiss::idx_t n = 1; + + try { + index_->search(n, reinterpret_cast(target.data()), k_, + distances_.data(), labels_.data(), ¶ms); + } catch (const std::exception& e) { + status_ = Status::InvalidArgument(e.what()); + } + } + + void SeekForPrev(const Slice& /* target */) override { + status_ = + Status::NotSupported("SeekForPrev not supported for FaissIVFIndex"); + } + + void Next() override { + assert(Valid()); + + ++pos_; + } + + void Prev() override { + assert(Valid()); + + --pos_; + } + + Status status() const override { return status_; } + + Slice key() const override { + assert(Valid()); + assert(labels_[pos_] >= 0); + assert(labels_[pos_] < keys_.size()); + + return keys_[labels_[pos_]]; + } + + Slice value() const override { + assert(Valid()); + + return Slice(); + } + + const WideColumns& columns() const override { + assert(Valid()); + + return kNoWideColumns; + } + + Slice timestamp() const override { + assert(Valid()); + + return Slice(); + } + + Status GetProperty(std::string prop_name, std::string* prop) override { + if (!prop) { + return Status::InvalidArgument("No property pointer provided"); + } + + if (!Valid()) { + return Status::InvalidArgument("Iterator is not valid"); + } + + if (prop_name == kPropertyName_) { + *prop = std::to_string(distances_[pos_]); + return Status::OK(); + } + + return Iterator::GetProperty(std::move(prop_name), prop); + } + + private: + faiss::IndexIVF* index_; + Iterator* underlying_it_; + size_t k_; + size_t probes_; + std::vector distances_; + std::vector labels_; + Status status_; + faiss::idx_t pos_; + autovector keys_; + + static const std::string kPropertyName_; +}; + +const std::string FaissIVFIndex::KNNIterator::kPropertyName_ = + "rocksdb.faiss.ivf.index.distance"; + class FaissIVFIndex::Adapter : public faiss::InvertedLists { public: Adapter(size_t num_lists, size_t code_size) @@ -36,14 +179,13 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists { return nullptr; } - // Iterator-based read interface; not yet implemented + // Iterator-based read interface faiss::InvertedListsIterator* get_iterator( - size_t /* list_no */, - void* /* inverted_list_context */ = nullptr) const override { - // TODO: implement this + size_t list_no, void* inverted_list_context = nullptr) const override { + KNNIterator* const it = static_cast(inverted_list_context); + assert(it); - assert(false); - return nullptr; + return new IteratorAdapter(it, list_no, code_size); } // Write interface; only add_entry is implemented/required for now @@ -80,6 +222,77 @@ class FaissIVFIndex::Adapter : public faiss::InvertedLists { void resize(size_t /* list_no */, size_t /* new_size */) override { assert(false); } + + private: + class IteratorAdapter : public faiss::InvertedListsIterator { + public: + IteratorAdapter(KNNIterator* it, size_t list_no, size_t code_size) + : it_(it), + underlying_it_(it->GetUnderlyingIterator()), + prefix_(FaissIVFIndex::SerializeLabel(list_no)), + prefix_slice_(prefix_), + code_size_(code_size) { + assert(it_); + assert(underlying_it_); + + // FIXME: here we rely on the empty Slice being less than any other one, + // which is true for e.g. BytewiseComparator but not in general + underlying_it_->Seek(prefix_slice_); + Update(); + } + + bool is_available() const override { return id_and_codes_.has_value(); } + + void next() override { + underlying_it_->Next(); + Update(); + } + + std::pair get_id_and_codes() override { + assert(is_available()); + + return *id_and_codes_; + } + + private: + void Update() { + id_and_codes_.reset(); + + if (!underlying_it_->Valid()) { + return; + } + + Slice key = underlying_it_->key(); + if (!key.starts_with(prefix_slice_)) { + return; + } + + if (!underlying_it_->PrepareValue()) { + throw std::runtime_error( + "Failed to prepare value during iteration in FaissIVFIndex"); + } + + const Slice& value = underlying_it_->value(); + if (value.size() != code_size_) { + throw std::runtime_error( + "Code with unexpected size encountered during iteration in " + "FaissIVFIndex"); + } + + key.remove_prefix(prefix_slice_.size()); + + const faiss::idx_t id = it_->AddKey(key.ToString()); + + id_and_codes_.emplace(id, reinterpret_cast(value.data())); + } + + KNNIterator* it_; + Iterator* underlying_it_; + std::string prefix_; + Slice prefix_slice_; + size_t code_size_; + std::optional> id_and_codes_; + }; }; std::string FaissIVFIndex::SerializeLabel(faiss::idx_t label) { @@ -105,6 +318,7 @@ FaissIVFIndex::FaissIVFIndex(std::unique_ptr&& index, assert(index_); assert(index_->quantizer); + index_->parallel_mode = 0; index_->replace_invlists(adapter_.get()); } @@ -203,7 +417,7 @@ Status FaissIVFIndex::GetSecondaryValue( if (code_str.size() != index_->code_size) { return Status::InvalidArgument( - "Unexpected code returned by fine quantizer"); + "Code with unexpected size returned by fine quantizer"); } secondary_value->emplace(std::move(code_str)); @@ -211,4 +425,11 @@ Status FaissIVFIndex::GetSecondaryValue( return Status::OK(); } +std::unique_ptr FaissIVFIndex::NewIterator( + const ReadOptions& read_options, Iterator* it) const { + return std::make_unique(index_.get(), it, + read_options.similarity_search_neighbors, + read_options.similarity_search_probes); +} + } // namespace ROCKSDB_NAMESPACE diff --git a/utilities/secondary_index/faiss_ivf_index.h b/utilities/secondary_index/faiss_ivf_index.h index 956dba7762ed..97807b355610 100644 --- a/utilities/secondary_index/faiss_ivf_index.h +++ b/utilities/secondary_index/faiss_ivf_index.h @@ -44,7 +44,11 @@ class FaissIVFIndex : public SecondaryIndex { std::optional>* secondary_value) const override; + std::unique_ptr NewIterator(const ReadOptions& read_options, + Iterator* underlying_it) const override; + private: + class KNNIterator; class Adapter; static std::string SerializeLabel(faiss::idx_t label); diff --git a/utilities/secondary_index/faiss_ivf_index_test.cc b/utilities/secondary_index/faiss_ivf_index_test.cc index 5d2008a47a7c..1215cc618c3d 100644 --- a/utilities/secondary_index/faiss_ivf_index_test.cc +++ b/utilities/secondary_index/faiss_ivf_index_test.cc @@ -33,8 +33,6 @@ TEST(FaissIVFIndexTest, Basic) { index->train(num_vectors, embeddings.data()); - index->nprobe = 2; - const std::string db_name = test::PerThreadDBPath("faiss_ivf_index_test"); EXPECT_OK(DestroyDB(db_name, Options())); @@ -113,6 +111,52 @@ TEST(FaissIVFIndexTest, Basic) { ASSERT_OK(it->status()); ASSERT_EQ(num_found, num_vectors); } + + ReadOptions read_options; + read_options.similarity_search_neighbors = 8; + read_options.similarity_search_probes = num_lists; + + std::unique_ptr underlying_it(db->NewIterator(read_options, cfh2)); + + std::unique_ptr it = + txn_db_options.secondary_indices.back()->NewIterator(read_options, + underlying_it.get()); + + float distance = 0.0f; + + auto get_distance_prop = [&]() { + std::string distance_str; + ASSERT_OK( + it->GetProperty("rocksdb.faiss.ivf.index.distance", &distance_str)); + ASSERT_EQ( + std::from_chars(distance_str.data(), + distance_str.data() + distance_str.size(), distance) + .ec, + std::errc()); + }; + + // Search for the first vector; we expect to find the vector itself as the + // closest match, since we're doing exhaustive search + it->Seek(Slice(reinterpret_cast(embeddings.data()), + dim * sizeof(float))); + get_distance_prop(); + ASSERT_TRUE(it->Valid()); + ASSERT_EQ(it->key(), "0"); + ASSERT_EQ(distance, 0.0f); + + float prev_distance = distance; + + size_t num_found = 1; + for (it->Next(); it->Valid(); it->Next()) { + get_distance_prop(); + ASSERT_GE(distance, prev_distance); + + prev_distance = distance; + ++num_found; + } + + ASSERT_OK(it->status()); + ASSERT_EQ(num_found, read_options.similarity_search_neighbors); } } // namespace ROCKSDB_NAMESPACE diff --git a/utilities/transactions/transaction_test.cc b/utilities/transactions/transaction_test.cc index e51572fa4005..2f2909d74387 100644 --- a/utilities/transactions/transaction_test.cc +++ b/utilities/transactions/transaction_test.cc @@ -26,6 +26,7 @@ #include "test_util/testharness.h" #include "test_util/testutil.h" #include "test_util/transaction_test_util.h" +#include "util/overload.h" #include "util/random.h" #include "util/string_util.h" #include "utilities/merge_operators.h" @@ -6290,9 +6291,9 @@ TEST_P(TransactionTest, DuplicateKeys) { } } delete cf_handle; - } // with_commit_batch - } // do_rollback - } // do_prepare + } + } + } if (!options.unordered_write) { // Also test with max_successive_merges > 0. max_successive_merges will not @@ -8084,7 +8085,121 @@ TEST_P(TransactionTest, SecondaryIndex) { return Status::OK(); } + std::unique_ptr NewIterator( + const ReadOptions& /* read_options */, + Iterator* underlying_it) const override { + return std::make_unique(this, underlying_it); + } + private: + class FooIterator : public Iterator { + public: + FooIterator(const SecondaryIndex* index, Iterator* underlying_it) + : index_(index), underlying_it_(underlying_it) { + assert(index_); + assert(underlying_it_); + } + + bool Valid() const override { + return status_.ok() && underlying_it_->Valid() && + underlying_it_->key().starts_with(prefix_); + } + + void SeekToFirst() override { + status_ = Status::NotSupported("SeekToFirst"); + } + + void SeekToLast() override { + status_ = Status::NotSupported("SeekToLast"); + } + + void Seek(const Slice& target) override { + status_ = Status::OK(); + + std::variant prefix; + + const Status s = + index_->GetSecondaryKeyPrefix(Slice(), target, &prefix); + if (!s.ok()) { + status_ = s; + return; + } + + prefix_ = std::visit( + overload{ + [](const Slice& value) -> std::string { + return value.ToString(); + }, + [](const std::string& value) -> std::string { return value; }}, + prefix); + + underlying_it_->Seek(prefix_); + } + + void SeekForPrev(const Slice& /* target */) override { + status_ = Status::NotSupported("SeekForPrev"); + } + + void Next() override { + assert(Valid()); + + underlying_it_->Next(); + } + + void Prev() override { + assert(Valid()); + + underlying_it_->Prev(); + } + + bool PrepareValue() override { + assert(Valid()); + + return underlying_it_->PrepareValue(); + } + + Status status() const override { + if (!status_.ok()) { + return status_; + } + + return underlying_it_->status(); + } + + Slice key() const override { + assert(Valid()); + + Slice key = underlying_it_->key(); + key.remove_prefix(prefix_.size()); + + return key; + } + + Slice value() const override { + assert(Valid()); + + return underlying_it_->value(); + } + + const WideColumns& columns() const override { + assert(Valid()); + + return underlying_it_->columns(); + } + + Slice timestamp() const override { + assert(Valid()); + + return Slice(); + } + + private: + const SecondaryIndex* index_; + Iterator* underlying_it_; + Status status_; + std::string prefix_; + }; + ColumnFamilyHandle* primary_cfh_{}; ColumnFamilyHandle* secondary_cfh_{}; }; @@ -8199,6 +8314,31 @@ TEST_P(TransactionTest, SecondaryIndex) { ASSERT_OK(it->status()); } + { + std::unique_ptr underlying_it( + db->NewIterator(ReadOptions(), cfh2)); + std::unique_ptr it( + index->NewIterator(ReadOptions(), underlying_it.get())); + + it->Seek("box"); // last character used for indexing: x + ASSERT_TRUE(it->Valid()); + ASSERT_EQ(it->key(), "key3"); + ASSERT_EQ(it->value(), "zab"); + + it->Next(); + ASSERT_TRUE(it->Valid()); + ASSERT_EQ(it->key(), "key4"); + ASSERT_EQ(it->value(), "xuuq"); + + it->Next(); + ASSERT_FALSE(it->Valid()); + ASSERT_OK(it->status()); + + it->Seek("toy"); // last character used for indexing: y + ASSERT_FALSE(it->Valid()); + ASSERT_OK(it->status()); + } + // Make some updates to the key-values indexed above through the database // interface (i.e. using implicit transactions) @@ -8273,6 +8413,31 @@ TEST_P(TransactionTest, SecondaryIndex) { ASSERT_FALSE(it->Valid()); ASSERT_OK(it->status()); } + + { + std::unique_ptr underlying_it( + db->NewIterator(ReadOptions(), cfh2)); + std::unique_ptr it( + index->NewIterator(ReadOptions(), underlying_it.get())); + + it->Seek("bot"); // last character used for indexing: t + ASSERT_TRUE(it->Valid()); + ASSERT_EQ(it->key(), "key1"); + ASSERT_EQ(it->value(), "tluarg"); + + it->Next(); + ASSERT_FALSE(it->Valid()); + ASSERT_OK(it->status()); + + it->Seek("toy"); // last character used for indexing: y + ASSERT_TRUE(it->Valid()); + ASSERT_EQ(it->key(), "key3"); + ASSERT_EQ(it->value(), "ylprag"); + + it->Next(); + ASSERT_FALSE(it->Valid()); + ASSERT_OK(it->status()); + } } TEST_F(TransactionDBTest, CollapseKey) {