Skip to content

Commit

Permalink
Fixing unit test for Faiss due to faiss upgrade. (opensearch-project#951
Browse files Browse the repository at this point in the history
)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Jul 4, 2023
1 parent b84250e commit d55eccb
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ In addition to this, the plugin has been tested with JDK 17, and this JDK versio

#### CMake

The plugin requires that cmake >= 3.23.1 is installed in order to build the JNI libraries.
The plugin requires that cmake >= 3.23.3 is installed in order to build the JNI libraries.

One easy way to install on mac or linux is to use pip:
```bash
pip install cmake==3.23.1
pip install cmake==3.23.3
```

#### Faiss Dependencies
Expand Down
16 changes: 8 additions & 8 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ using ::testing::Return;

TEST(FaissCreateIndexTest, BasicAssertions) {
// Define the data
faiss::Index::idx_t numIds = 200;
std::vector<faiss::Index::idx_t> ids;
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
std::vector<std::vector<float>> vectors;
int dim = 2;
for (int64_t i = 0; i < numIds; ++i) {
Expand Down Expand Up @@ -70,8 +70,8 @@ TEST(FaissCreateIndexTest, BasicAssertions) {

TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
// Define the data
faiss::Index::idx_t numIds = 100;
std::vector<faiss::Index::idx_t> ids;
faiss::idx_t numIds = 100;
std::vector<faiss::idx_t> ids;
std::vector<std::vector<float>> vectors;
int dim = 2;
for (int64_t i = 0; i < numIds; ++i) {
Expand Down Expand Up @@ -122,8 +122,8 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {

TEST(FaissLoadIndexTest, BasicAssertions) {
// Define the data
faiss::Index::idx_t numIds = 100;
std::vector<faiss::Index::idx_t> ids;
faiss::idx_t numIds = 100;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
int dim = 2;
for (int64_t i = 0; i < numIds; i++) {
Expand Down Expand Up @@ -174,8 +174,8 @@ TEST(FaissLoadIndexTest, BasicAssertions) {

TEST(FaissQueryIndexTest, BasicAssertions) {
// Define the index data
faiss::Index::idx_t numIds = 100;
std::vector<faiss::Index::idx_t> ids;
faiss::idx_t numIds = 100;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
int dim = 16;
for (int64_t i = 0; i < numIds; i++) {
Expand Down
6 changes: 3 additions & 3 deletions jni/tests/test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ faiss::Index *test_util::FaissLoadFromSerializedIndex(
}

faiss::IndexIDMap test_util::FaissAddData(faiss::Index *index,
std::vector<faiss::Index::idx_t> ids,
std::vector<faiss::idx_t> ids,
std::vector<float> dataset) {
faiss::IndexIDMap idMap = faiss::IndexIDMap(index);
idMap.add_with_ids(ids.size(), dataset.data(), ids.data());
Expand All @@ -251,11 +251,11 @@ faiss::Index *test_util::FaissLoadIndex(const std::string &indexPath) {
}

void test_util::FaissQueryIndex(faiss::Index *index, float *query, int k,
float *distances, faiss::Index::idx_t *ids) {
float *distances, faiss::idx_t *ids) {
index->search(1, query, k, distances, ids);
}

void test_util::FaissTrainIndex(faiss::Index *index, faiss::Index::idx_t n,
void test_util::FaissTrainIndex(faiss::Index *index, faiss::idx_t n,
const float *x) {
if (auto *indexIvf = dynamic_cast<faiss::IndexIVF *>(index)) {
if (indexIvf->quantizer_trains_alone == 2) {
Expand Down
6 changes: 3 additions & 3 deletions jni/tests/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,17 @@ namespace test_util {
faiss::Index* FaissLoadFromSerializedIndex(std::vector<uint8_t>* indexSerial);

faiss::IndexIDMap FaissAddData(faiss::Index* index,
std::vector<faiss::Index::idx_t> ids,
std::vector<faiss::idx_t> ids,
std::vector<float> dataset);

void FaissWriteIndex(faiss::Index* index, const std::string& indexPath);

faiss::Index* FaissLoadIndex(const std::string& indexPath);

void FaissQueryIndex(faiss::Index* index, float* query, int k, float* distances,
faiss::Index::idx_t* ids);
faiss::idx_t* ids);

void FaissTrainIndex(faiss::Index* index, faiss::Index::idx_t n,
void FaissTrainIndex(faiss::Index* index, faiss::idx_t n,
const float* x);

// -------------------------------------------------------------------------------
Expand Down

0 comments on commit d55eccb

Please sign in to comment.