Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing a loading layer in FAISS native engine. #2139

Merged
merged 10 commits into from
Oct 3, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Adds concurrent segment search support for mode auto [#2111](https://github.com/opensearch-project/k-NN/pull/2111)
* Introducing a loading layer in FAISS [#2033](https://github.com/opensearch-project/k-NN/issues/2033)
### Bug Fixes
* Add DocValuesProducers for releasing memory when close index [#1946](https://github.com/opensearch-project/k-NN/pull/1946)
### Infrastructure
Expand Down
128 changes: 128 additions & 0 deletions jni/include/faiss_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H

#include "faiss/impl/io.h"

#include <jni.h>
#include <stdexcept>
#include <iostream>
#include <cstring>

namespace knn_jni { namespace stream {



/**
0ctopus13prime marked this conversation as resolved.
Show resolved Hide resolved
* This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer.
*/
class NativeEngineIndexInputMediator {
public:
// Expect IndexInputWithBuffer is given as `_indexInput`.
NativeEngineIndexInputMediator(JNIEnv * _env, jobject _indexInput)
: env(_env),
indexInput(_indexInput),
bufferArray((jbyteArray) (_env->GetObjectField(_indexInput, getBufferFieldId(_env)))),
copyBytesMethod(getCopyBytesMethod(_env)) {
}

void copyBytes(int32_t nbytes, uint8_t* destination) {
while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
const auto readBytes =
env->CallIntMethod(indexInput, copyBytesMethod, nbytes);

// === Critical Section Start ===

// Get primitive array pointer, no copy is happening in OpenJDK.
jbyte* primitiveArray =
(jbyte*) env->GetPrimitiveArrayCritical(bufferArray, NULL);

// Copy Java bytes to C++ destination address.
std::memcpy(destination, primitiveArray, readBytes);

// Release the acquired primitive array pointer.
// JNI_ABORT tells JVM to directly free memory without copying back to Java byte[].
// Since we're merely copying data, we don't need to copying back.
env->ReleasePrimitiveArrayCritical(bufferArray, primitiveArray, JNI_ABORT);

// === Critical Section End ===

destination += readBytes;
nbytes -= readBytes;
} // End while
}

private:
static jclass getIndexInputWithBufferClass(JNIEnv * env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
env->FindClass("org/opensearch/knn/index/util/IndexInputWithBuffer");
return INDEX_INPUT_WITH_BUFFER_CLASS;
}

static jmethodID getCopyBytesMethod(JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
env->GetMethodID(getIndexInputWithBufferClass(env), "copyBytes", "(J)I");
return COPY_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID = env->GetFieldID(getIndexInputWithBufferClass(env), "buffer", "[B");
return BUFFER_FIELD_ID;
}

JNIEnv * env;

// `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading.
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
}; // class NativeEngineIndexInputMediator



/**
* A glue component inheriting IOReader to be passed down to Faiss library.
* This will then indirectly call the mediator component and eventually read required bytes from Lucene's IndexInput.
*/
class FaissOpenSearchIOReader final : public faiss::IOReader {
public:
explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator* _mediator)
: faiss::IOReader(),
mediator(_mediator) {
name = "FaissOpenSearchIOReader";
}

size_t operator()(void* ptr, size_t size, size_t nitems) final {
const auto readBytes = size * nitems;
if (readBytes > 0) {
// Mediator calls IndexInput, then copy read bytes to `ptr`.
mediator->copyBytes(readBytes, (uint8_t *) ptr);
}
return nitems;
}

int filedescriptor() final {
throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOReader.");
}

private:
NativeEngineIndexInputMediator* mediator;
}; // class FaissOpenSearchIOReader



}
}

#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
10 changes: 10 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,21 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Loads an index with a reader implemented IOReader
//
// Returns a pointer of the loaded index
jlong LoadIndexWithStream(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::IOReader* ioReader);

// Load a binary index from indexPathJ into memory.
//
// Return a pointer to the loaded index
jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Loads a binary index with a reader implemented IOReader
//
// Returns a pointer of the loaded index
jlong LoadBinaryIndexWithStream(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::IOReader* ioReader);

// Check if a loaded index requires shared state
bool IsSharedIndexStateRequired(jlong indexPointerJ);

Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndexWithStream
* Signature: (Lorg/opensearch/knn/index/util/IndexInputWithBuffer;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream
(JNIEnv *, jclass, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndex
Expand All @@ -136,6 +144,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndexWithStream
* Signature: (Lorg/opensearch/knn/index/util/IndexInputWithBuffer;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream
(JNIEnv *, jclass, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: isSharedIndexStateRequired
Expand Down
28 changes: 28 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,20 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI
return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadIndexWithStream(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::IOReader* ioReader) {
if (ioReader == nullptr) [[unlikely]] {
throw std::runtime_error("IOReader cannot be null");
}

faiss::Index* indexReader =
faiss::read_index(ioReader,
faiss::IO_FLAG_READ_ONLY
| faiss::IO_FLAG_PQ_SKIP_SDC_TABLE
| faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE);

return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand All @@ -436,6 +450,20 @@ jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUti
return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadBinaryIndexWithStream(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::IOReader* ioReader) {
if (ioReader == nullptr) [[unlikely]] {
throw std::runtime_error("IOReader cannot be null");
}

faiss::IndexBinary* indexReader =
faiss::read_index_binary(ioReader,
faiss::IO_FLAG_READ_ONLY
| faiss::IO_FLAG_PQ_SKIP_SDC_TABLE
| faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE);

return (jlong) indexReader;
}

bool knn_jni::faiss_wrapper::IsSharedIndexStateRequired(jlong indexPointerJ) {
auto * index = reinterpret_cast<faiss::Index*>(indexPointerJ);
return isIndexIVFPQL2(index);
Expand Down
47 changes: 47 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "faiss_wrapper.h"
#include "jni_util.h"
#include "faiss_stream_support.h"

static knn_jni::JNIUtil jniUtil;
static const jint KNN_FAISS_JNI_VERSION = JNI_VERSION_1_1;
Expand Down Expand Up @@ -217,6 +218,29 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEn
return NULL;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream
(JNIEnv * env, jclass cls, jobject readStream)
{
try {
// Create a mediator locally.
// Note that `indexInput` is `IndexInputWithBuffer` type.
knn_jni::stream::NativeEngineIndexInputMediator mediator {env, readStream};

// Wrap the mediator with a glue code inheriting IOReader.
knn_jni::stream::FaissOpenSearchIOReader faissOpenSearchIOReader {&mediator};

// Pass IOReader to Faiss for loading vector index.
return knn_jni::faiss_wrapper::LoadIndexWithStream(
&jniUtil,
env,
&faissOpenSearchIOReader);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}

return NULL;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
{
try {
Expand All @@ -227,6 +251,29 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
return NULL;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream
(JNIEnv * env, jclass cls, jobject readStream)
{
try {
// Create a mediator locally.
// Note that `indexInput` is `IndexInputWithBuffer` type.
knn_jni::stream::NativeEngineIndexInputMediator mediator {env, readStream};

// Wrap the mediator with a glue code inheriting IOReader.
knn_jni::stream::FaissOpenSearchIOReader faissOpenSearchIOReader {&mediator};

// Pass IOReader to Faiss for loading vector index.
return knn_jni::faiss_wrapper::LoadBinaryIndexWithStream(
&jniUtil,
env,
&faissOpenSearchIOReader);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}

return NULL;
}

JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired
(JNIEnv * env, jclass cls, jlong indexPointerJ)
{
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.common.lucene.Lucene;
Expand Down Expand Up @@ -89,11 +90,13 @@ public String getIndexName() {
*/
public void warmup() throws IOException {
log.info("[KNN] Warming up index: [{}]", getIndexName());
final Directory directory = indexShard.store().directory();
try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) {
getAllEngineFileContexts(searcher.getIndexReader()).forEach((engineFileContext) -> {
try {
nativeMemoryCacheManager.get(
new NativeMemoryEntryContext.IndexEntryContext(
directory,
engineFileContext.getIndexPath(),
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(
Expand Down
Loading
Loading