Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

faiss interface refactoring to support multiple methods #344

Closed
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
63a8a9f
Refactoring plugin code for supporting multiple engines
jmazanec15 Mar 9, 2021
56a4c73
Refactor spaceType to be flat in plugin
jmazanec15 Mar 9, 2021
4bbfc26
Refactor ANN scoring to support multiple engines
jmazanec15 Mar 9, 2021
7e6efd8
Refactor inner product score translation
jmazanec15 Mar 9, 2021
b16abea
Add method parameter for parsing method config
jmazanec15 Mar 18, 2021
16eaf35
Add support for faiss indices that require training
jmazanec15 Mar 30, 2021
96cb855
add PQ encoding support for faiss
jmazanec15 Mar 31, 2021
629d205
Modify training to use index data
jmazanec15 Mar 31, 2021
1aa25f9
Switch to debug log statements
jmazanec15 Mar 31, 2021
632d8dc
Add support for faiss flat index
jmazanec15 Apr 1, 2021
80f8f9e
Adjust training points to 5K
jmazanec15 Apr 1, 2021
705f794
Add support for extra parameters in jni and clean code
jmazanec15 Apr 6, 2021
786ca2d
Clean up lib versioning
jmazanec15 Apr 6, 2021
210a948
Remove unnecessary params from faiss jni
jmazanec15 Apr 6, 2021
45f80ed
Dont generate extra parameters for nmslib
jmazanec15 Apr 6, 2021
fa84683
Set default parameter values for faiss
jmazanec15 Apr 6, 2021
fd33d64
Refactor structure of engine functions
jmazanec15 Apr 7, 2021
d1417a9
Rename course to coarse
jmazanec15 Apr 7, 2021
fdcdb43
Support method context for nmslib hnsw parameters
jmazanec15 Apr 7, 2021
62cd44d
Pull strings out into constants
jmazanec15 Apr 7, 2021
952dfd6
Refactor spaceType passing logic
jmazanec15 Apr 7, 2021
cafc36e
Fix case for null parameter
jmazanec15 Apr 8, 2021
aa77e34
Rename FAISSLibVersion to FaissLibVersion
jmazanec15 Apr 8, 2021
dc381fb
Improve parsing implementation
jmazanec15 Apr 9, 2021
608c8c1
Make training limits configurable
jmazanec15 Apr 9, 2021
6695cb8
Allow pq for flat faiss index
jmazanec15 Apr 12, 2021
dcc05d1
Add extra params for hnsw
jmazanec15 Apr 14, 2021
75e43d8
Minor clean up
jmazanec15 Apr 16, 2021
3dce15d
Refactor engine logic
jmazanec15 Apr 16, 2021
99ebf56
Minor refactoring to validation logic
jmazanec15 Apr 19, 2021
76e19de
Add method interface uTs and fix broken old ones
jmazanec15 Apr 20, 2021
821acf4
Add null parameters check
jmazanec15 Apr 20, 2021
4144732
Fix query stat counter
jmazanec15 Apr 20, 2021
0c51442
Fix IT for faiss
jmazanec15 Apr 20, 2021
799e816
Remove index name from jni functions
jmazanec15 Apr 23, 2021
fbcb76d
Minor refactoring
jmazanec15 Apr 28, 2021
18ca977
Move parameter and methodcomponent into individual file
jmazanec15 Apr 28, 2021
22501e3
Use builder to build MethodComponent
jmazanec15 Apr 28, 2021
5139e7a
Add builder for KNNMethod
jmazanec15 Apr 28, 2021
428c7ef
Refactor to improve readability
jmazanec15 Apr 28, 2021
5a01566
Refactor KNNMethodContext to improve readability
jmazanec15 Apr 28, 2021
861a0df
Separate out MethodComponentContext from KNNMethodContext
jmazanec15 Apr 29, 2021
080b4d6
Clean up library code
jmazanec15 Apr 29, 2021
d8b6cc2
Rename putParameter to addParameter
jmazanec15 Apr 29, 2021
523595f
Minor fixes
jmazanec15 Apr 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -18,150 +18,222 @@

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
#include <sys/time.h>
#include <omp.h>

#include "faiss/index_factory.h"
#include "faiss/MetaIndexes.h"
#include "faiss/index_io.h"
#include "faiss/IndexHNSW.h"
#include "faiss/IndexIVFFlat.h"


using std::string;
using std::vector;

// mapMetric is used to map a string from the plugin to a faiss metric. All translation should be done via this map
// Keys are matched exactly as written here: "l2" selects faiss::METRIC_L2 and "innerproduct" selects
// faiss::METRIC_INNER_PRODUCT. Any other string has no entry in this map.
std::unordered_map<string, faiss::MetricType> mapMetric = {
{"l2", faiss::METRIC_L2},
{"innerproduct", faiss::METRIC_INNER_PRODUCT}
};

/**
 * Trains a faiss index on the supplied vectors if it has not been trained yet.
 *
 * When the index is an IVF index, two extra steps happen first: if its
 * quantizer_trains_alone field equals 2 the coarse quantizer is trained
 * recursively through this same function, and make_direct_map() is called
 * on the IVF index.
 *
 * @param index index to train
 * @param n     number of training vectors passed to faiss
 * @param x     pointer to the training data handed to index->train
 */
void TrainIndex(faiss::Index * index, faiss::Index::idx_t n, const float* x) {
    auto * ivfIndex = dynamic_cast<faiss::IndexIVF*>(index);
    if (ivfIndex != nullptr) {
        const bool quantizerNeedsOwnTraining = (ivfIndex->quantizer_trains_alone == 2);
        if (quantizerNeedsOwnTraining) {
            TrainIndex(ivfIndex->quantizer, n, x);
        }
        ivfIndex->make_direct_map();
    }

    // Nothing left to do for indices that report themselves as trained
    if (index->is_trained) {
        return;
    }
    index->train(n, x);
}

/**
 * Applies the entries of a java Map<String, Object> to a faiss index. These are parameters that are set directly
 * on the index object rather than through the index description.
 *
 * Recognized keys:
 *   - IVF indices:  "nprobes" (Integer) sets nprobe; "coarse_quantizer" (nested Map) is applied recursively to
 *                   the IVF quantizer when one is present.
 *   - HNSW indices: "ef_construction" (Integer) and "ef_search" (Integer).
 * Every other key, and every key on an index that is neither IVF nor HNSW, is ignored.
 *
 * @param env          JNI environment used for all calls into the JVM
 * @param parameterMap java.util.Map<String, Object> holding the parameters
 * @param index        faiss index to configure
 * @throws std::runtime_error when a parameter that must be an Integer holds a different type
 *
 * NOTE(review): knn_jni::HasExceptionInStack is assumed to raise a C++ exception when a Java exception is
 * pending - confirm against its definition.
 */
void SetExtraParameters(JNIEnv *env, jobject parameterMap, faiss::Index * index) {
    // Here, we parse parameterMap, which is a java Map<String, Object>. In order to implement this, I referred to
    // https://stackoverflow.com/questions/4844022/jni-create-hashmap

    // Load all of the class and methods to iterate over a map
    jclass mapClass = knn_jni::FindClass(env, "java/util/Map");
    jmethodID entrySet = knn_jni::FindMethod(env, mapClass, "entrySet", "()Ljava/util/Set;");

    jobject parameterEntrySet = env->CallObjectMethod(parameterMap, entrySet);
    knn_jni::HasExceptionInStack(env, "Unable to call \"entrySet\" method on \"java/util/Map\"");

    jclass setClass = knn_jni::FindClass(env, "java/util/Set");
    jmethodID iterator = knn_jni::FindMethod(env, setClass, "iterator", "()Ljava/util/Iterator;");

    jclass iteratorClass = knn_jni::FindClass(env, "java/util/Iterator");

    jobject iter = env->CallObjectMethod(parameterEntrySet, iterator);
    knn_jni::HasExceptionInStack(env, "Call to \"iterator\" method failed");

    jmethodID hasNext = knn_jni::FindMethod(env, iteratorClass, "hasNext", "()Z");
    jmethodID next = knn_jni::FindMethod(env, iteratorClass, "next", "()Ljava/lang/Object;");

    jclass entryClass = knn_jni::FindClass(env, "java/util/Map$Entry");
    jmethodID getKey = knn_jni::FindMethod(env, entryClass, "getKey", "()Ljava/lang/Object;");
    jmethodID getValue = knn_jni::FindMethod(env, entryClass, "getValue", "()Ljava/lang/Object;");

    jclass integerClass = knn_jni::FindClass(env, "java/lang/Integer");
    jmethodID intValue = knn_jni::FindMethod(env, integerClass, "intValue", "()I");

    // The concrete type of the index cannot change while iterating, so resolve it once up front
    auto * indexIvf = dynamic_cast<faiss::IndexIVF*>(index);
    auto * indexHnsw = dynamic_cast<faiss::IndexHNSW*>(index);

    // Unboxes a java.lang.Integer; rejects any other type before touching the target field
    auto unboxInteger = [&](jobject boxed) -> int {
        if (!env->IsInstanceOf(boxed, integerClass)) {
            throw std::runtime_error("Cannot call IntMethod on non-integer class");
        }
        int unboxed = env->CallIntMethod(boxed, intValue);
        knn_jni::HasExceptionInStack(env, "Could not call \"intValue\" method on Integer");
        return unboxed;
    };

    // Iterate over the entry Set
    while (true) {
        // Check for a pending exception right after hasNext so that no further JNI call is made on top of it
        const bool hasMore = env->CallBooleanMethod(iter, hasNext);
        knn_jni::HasExceptionInStack(env, "Could not call \"hasNext\" method");
        if (!hasMore) {
            break;
        }

        jobject entry = env->CallObjectMethod(iter, next);
        knn_jni::HasExceptionInStack(env, "Could not call \"next\" method");

        // NOTE(review): the jstring returned by getKey is not released here because it is unclear from this
        // file whether knn_jni::GetStringJenv already does so - confirm
        std::string key = knn_jni::GetStringJenv(env, static_cast<jstring>(env->CallObjectMethod(entry, getKey)));
        knn_jni::HasExceptionInStack(env, "Could not call \"getKey\" method");

        jobject value = env->CallObjectMethod(entry, getValue);
        knn_jni::HasExceptionInStack(env, "Could not call \"getValue\" method");

        if (indexIvf != nullptr) {
            if (key == "nprobes") {
                indexIvf->nprobe = unboxInteger(value);
            } else if (key == "coarse_quantizer" && indexIvf->quantizer != nullptr) {
                SetExtraParameters(env, value, indexIvf->quantizer);
            }
        }

        if (indexHnsw != nullptr) {
            if (key == "ef_construction") {
                indexHnsw->hnsw.efConstruction = unboxInteger(value);
            } else if (key == "ef_search") {
                indexHnsw->hnsw.efSearch = unboxInteger(value);
            }
        }

        // Release per-entry references on every iteration, regardless of the index type, so that indices that
        // are neither IVF nor HNSW do not accumulate local references
        env->DeleteLocalRef(value);
        env->DeleteLocalRef(entry);
    }

    env->DeleteLocalRef(iter);
    env->DeleteLocalRef(parameterEntrySet);
}

/**
* Method: saveIndex
*
*/
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_saveIndex
(JNIEnv* env, jclass cls, jintArray ids, jobjectArray vectors, jstring indexPath, jobjectArray algoParams, jstring spaceType)
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_save
(JNIEnv* env, jclass cls, jintArray ids, jobjectArray vectors, jstring indexPath, jobject parameterMap,
jstring spaceType, jstring indexDescription, jint trainingDatasetSizeLimit, jint minimumDatapoints)
{
vector<int64_t> idVector;
vector<float> dataset;
vector<string> paramsList;
//TODO we can support other FAISS index in the future, may be paramsList can add index=xxxx
string indexDescription = "HNSW32";
faiss::MetricType metric = faiss::METRIC_L2;
std::unique_ptr<faiss::Index> indexWriter;
int dim = 0;
try {
//---- ids
int* object_ids = NULL;
object_ids = env->GetIntArrayElements(ids, 0);
for(int i = 0; i < env->GetArrayLength(ids); ++i) {
idVector.push_back(object_ids[i]);
}
env->ReleaseIntArrayElements(ids, object_ids, 0);
knn_jni::HasExceptionInStack(env);

//---- vectors
for (int i = 0; i < env->GetArrayLength(vectors); ++i) {
jfloatArray vectorArray = (jfloatArray)env->GetObjectArrayElement(vectors, i);
float* vector = env->GetFloatArrayElements(vectorArray, 0);
dim = env->GetArrayLength(vectorArray);
for(int j = 0; j < dim; ++j) {
dataset.push_back(vector[j]);
}
env->ReleaseFloatArrayElements(vectorArray, vector, 0);
}
knn_jni::HasExceptionInStack(env);
vector<int64_t> idVector;
vector<float> dataset;
vector<string> paramsList;

//---- indexPath
const char *indexString = env->GetStringUTFChars(indexPath, 0);
string indexPathString(indexString);
env->ReleaseStringUTFChars(indexPath, indexString);
faiss::MetricType metric = faiss::METRIC_L2;
std::unique_ptr<faiss::Index> indexWriter;
int dim = 0;
try {
//---- ids
int* object_ids = nullptr;
object_ids = env->GetIntArrayElements(ids, 0);
for(int i = 0; i < env->GetArrayLength(ids); ++i) {
idVector.push_back(object_ids[i]);
}
env->ReleaseIntArrayElements(ids, object_ids, 0);
knn_jni::HasExceptionInStack(env);

//---- algoParams
int paramsCount = env->GetArrayLength(algoParams);
for (int i=0; i<paramsCount; i++) {
jstring param = (jstring) (env->GetObjectArrayElement(algoParams, i));
const char *rawString = env->GetStringUTFChars(param, 0);
paramsList.push_back(rawString);
//---- vectors
for (int i = 0; i < env->GetArrayLength(vectors); ++i) {
auto vectorArray = (jfloatArray)env->GetObjectArrayElement(vectors, i);
float* vector = env->GetFloatArrayElements(vectorArray, 0);
dim = env->GetArrayLength(vectorArray);
for(int j = 0; j < dim; ++j) {
dataset.push_back(vector[j]);
}
env->ReleaseFloatArrayElements(vectorArray, vector, 0);
}
knn_jni::HasExceptionInStack(env);

int M = 32;
if (sscanf(rawString, "M=%d", &M) == 1) {
indexDescription="HNSW"+std::to_string(M);
}
env->ReleaseStringUTFChars(param, rawString);
//---- indexPath
const char *indexString = env->GetStringUTFChars(indexPath, 0);
string indexPathString(indexString);
env->ReleaseStringUTFChars(indexPath, indexString);
knn_jni::HasExceptionInStack(env);

}
//---- space
const char *spaceTypeCStr = env->GetStringUTFChars(spaceType, 0);
string spaceTypeString(spaceTypeCStr);
env->ReleaseStringUTFChars(spaceType, spaceTypeCStr);
knn_jni::HasExceptionInStack(env);
// space mapping faiss::MetricType
if(mapMetric.find(spaceTypeString) != mapMetric.end()) {
metric = mapMetric[spaceTypeString];
}

//---- Create IndexWriter from faiss index_factory
// If data is less than a certain amount, just create a flat index
if (idVector.size() < (int) minimumDatapoints) {
indexWriter.reset(faiss::index_factory(dim, "Flat", metric));
} else {
std::string description = knn_jni::GetStringJenv(env, indexDescription);
indexWriter.reset(faiss::index_factory(dim, description.c_str(), metric));
}

//---- space
const char *spaceTypeCStr = env->GetStringUTFChars(spaceType, 0);
string spaceTypeString(spaceTypeCStr);
env->ReleaseStringUTFChars(spaceType, spaceTypeCStr);
knn_jni::HasExceptionInStack(env);
// space mapping faiss::MetricType
if(mapMetric.find(spaceTypeString) != mapMetric.end()) {
metric = mapMetric[spaceTypeString];
}
// Add extra parameters that cant be configured with the index factory
SetExtraParameters(env, parameterMap, indexWriter.get());
env->DeleteLocalRef(parameterMap);

//---- Create IndexWriter from faiss index_factory
indexWriter.reset(faiss::index_factory(dim, indexDescription.data(), metric));

//Preparation And TODO Verify IndexWriter
//Some Param Can not Create from IndexFactory, Like HNSW efSearch and efCOnstruction
//----FOR HNSW 1st PARAM: M(HNSW32->M=32), efConstruction, efSearch
if(indexDescription.find("HNSW") != std::string::npos) {
for(int i = 0; i < paramsCount; ++i) {
const string& param = paramsList[i];
int efConstruction = 40; //default
int efSearch = 16;//default
if(param.find("efConstruction") != std::string::npos &&
sscanf(param.data(), "efConstruction=%d", &efConstruction) == 1) {
faiss::IndexHNSW* ihp = reinterpret_cast<faiss::IndexHNSW*>(indexWriter.get());
ihp->hnsw.efConstruction = efConstruction;
} else if (param.find("efSearch") != std::string::npos &&
sscanf(param.data(), "efSearch=%d", &efSearch) == 1){
faiss::IndexHNSW* ihp = reinterpret_cast<faiss::IndexHNSW*>(indexWriter.get());
ihp->hnsw.efSearch = efSearch;
}
}
}
//---- Do Index
if(!indexWriter->is_trained) {
if (idVector.size() <= (int) trainingDatasetSizeLimit) {
TrainIndex(indexWriter.get(), idVector.size(), dataset.data());
} else {
vector<float>::const_iterator first = dataset.begin();
vector<float>::const_iterator last = dataset.begin() + ((int) trainingDatasetSizeLimit)*dim;
vector<float> subDataVector(first, last);
TrainIndex(indexWriter.get(), (int) trainingDatasetSizeLimit, subDataVector.data());
}
}

//---- Do Index
//----- 1. Train
if(!indexWriter->is_trained) {
//TODO if we use like PQ, we have to train dataset
// but when a lucene segment only one document, it
// can not train the data.
}
//----- 2. Add IDMap
// default all use self defined IndexIDMap cause some class no add_with_ids
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(idVector.size(), dataset.data(), idVector.data());
//----- 2. Add IDMap
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(idVector.size(), dataset.data(), idVector.data());

//----- 3. WriteIndex
faiss::write_index(&idMap, indexPathString.c_str());

//Explicit delete object
faiss::Index* indexPointer = indexWriter.release();
if(indexPointer) delete indexPointer;

delete indexPointer;
}
catch(...) {
faiss::Index* indexPointer = indexWriter.release();
if(indexPointer) delete indexPointer;
delete indexPointer;
knn_jni::CatchCppExceptionAndThrowJava(env);
}
}


JNIEXPORT jobjectArray JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_queryIndex
(JNIEnv* env, jclass cls, jlong indexPointer, jfloatArray queryVector, jint k)
JNIEXPORT jobjectArray JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_query
(JNIEnv* env, jclass cls, jlong indexPointer, jfloatArray queryVector, jint k)
{
faiss::Index *indexReader = nullptr;
try {
Expand Down Expand Up @@ -203,8 +275,8 @@ std::unordered_map<string, faiss::MetricType> mapMetric = {
return NULL;
}

JNIEXPORT jlong JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_init
(JNIEnv* env, jclass cls, jstring indexPath, jobjectArray algoParams, jstring spaceType)
JNIEXPORT jlong JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_init
(JNIEnv* env, jclass cls, jstring indexPath)
{

faiss::Index* indexReader = nullptr;
Expand All @@ -228,8 +300,8 @@ std::unordered_map<string, faiss::MetricType> mapMetric = {
* When autoclose class do close, then delete the pointer
* Method GC pointer
*/
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_gc
(JNIEnv* env, jclass cls, jlong indexPointer)
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_gc
(JNIEnv* env, jclass cls, jlong indexPointer)
{
try {
faiss::Index *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
Expand All @@ -246,7 +318,8 @@ std::unordered_map<string, faiss::MetricType> mapMetric = {
* Method: Global Init
*
*/
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_initLibrary(JNIEnv *, jclass)
JNIEXPORT void JNICALL Java_com_amazon_opendistroforelasticsearch_knn_index_faiss_v165_KNNFaissIndex_initLibrary
(JNIEnv *, jclass)
{
//set thread 1 cause ES has Search thread
//TODO make it different at search and write
Expand Down
Loading