Skip to content

Commit

Permalink
Expanded test coverage, added some more comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
LTLA committed Jun 13, 2024
1 parent bc31114 commit 2b153f8
Showing 1 changed file with 76 additions and 4 deletions.
80 changes: 76 additions & 4 deletions tests/src/Hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ TEST_P(HnswTest, FindEuclidean) {
knncolle::SimpleMatrix mat(ndim, nobs, data.data());
knncolle_hnsw::HnswBuilder<> builder;
auto bptr = builder.build_unique(mat);
EXPECT_EQ(bptr->num_dimensions(), ndim);
EXPECT_EQ(bptr->num_observations(), nobs);
auto bsptr = bptr->initialize();

// Trying with a different interface type.
Expand Down Expand Up @@ -222,7 +224,7 @@ TEST(Hnsw, Constructor) {
auto def_opt = def.get_options();
EXPECT_NE(def_opt.num_links, 10000);

// Checking that this is respected in the other constructor.
// Checking that this is respected in the overloaded constructor.
def_opt.num_links = 1000;
knncolle_hnsw::HnswBuilder<> mutant(def_opt);
EXPECT_EQ(mutant.get_options().num_links, 1000);
Expand All @@ -239,9 +241,44 @@ TEST(Hnsw, EuclideanDouble) {
d = dist(rng);
}
}
knncolle::SimpleMatrix<int, int, double> mat(ndim, nobs, data.data());

// Using a double as the InternalData_ to check that we dispatch correctly to a SquaredEuclideanDistance.
knncolle_hnsw::HnswBuilder<decltype(mat), double, double> builder;
auto bptr = builder.build_unique(mat);
auto bsptr = bptr->initialize();

std::vector<int> ires;
std::vector<double> dres;
for (int x = 0; x < nobs; ++x) {
bsptr->search(x, 10, &ires, &dres);

// Checking the distance to the most distant neighbor. Here we can afford to have some
// lower tolerances because everything's computed with doubles.
auto furthest = ires.back();
auto current = data.data() + x * ndim;
auto ptr = data.data() + furthest * ndim;
auto expected = knncolle::EuclideanDistance::raw_distance<double>(current, ptr, ndim);
EXPECT_LT(std::abs(knncolle::EuclideanDistance::normalize(expected) - dres.back()), 0.000001);
}
}

TEST(Hnsw, EuclideanNormalize) {
int ndim = 5;
int nobs = 100;
std::vector<double> data(ndim * nobs);
{
std::mt19937_64 rng(1001);
std::normal_distribution dist;
for (auto& d : data) {
d = dist(rng);
}
}
knncolle::SimpleMatrix<int, int, double> mat(ndim, nobs, data.data());
knncolle_hnsw::HnswBuilder<decltype(mat), double, double> builder;

// Checking that the normalization option is respected.
knncolle_hnsw::HnswBuilder<decltype(mat)> builder;
builder.get_options().distance_options.normalize = [&](float x) -> float { return x + 1; };
auto bptr = builder.build_unique(mat);
auto bsptr = bptr->initialize();

Expand All @@ -250,11 +287,46 @@ TEST(Hnsw, EuclideanDouble) {
for (int x = 0; x < nobs; ++x) {
bsptr->search(x, 10, &ires, &dres);

// Checking the distance to the most distant neighbor.
// Checking the distance to the most distant neighbor; using some more gentle tolerances again.
auto furthest = ires.back();
auto current = data.data() + x * ndim;
auto ptr = data.data() + furthest * ndim;
auto expected = knncolle::EuclideanDistance::raw_distance<double>(current, ptr, ndim);
EXPECT_LT(std::abs(knncolle::EuclideanDistance::normalize(expected) - dres.back()), 0.0001);
EXPECT_LT(std::abs((expected + 1) - dres.back()), 0.0001);
}
}

TEST(Hnsw, Duplicates) {
    // Checking that the neighbor identification works correctly when there are
    // so many duplicates that an observation doesn't get reported by HNSW in
    // its own list of neighbors.
    int ndim = 5;
    int nobs = 100;
    const int num_neighbors = 10;

    // The vector is value-initialized, i.e., filled with zeros, so every
    // observation is identical and all reported distances should be zero.
    std::vector<double> data(ndim * nobs);
    knncolle::SimpleMatrix<int, int, double> mat(ndim, nobs, data.data());

    knncolle_hnsw::HnswBuilder<decltype(mat)> builder;
    auto bptr = builder.build_unique(mat);
    auto bsptr = bptr->initialize();

    std::vector<int> ires, ires0;
    std::vector<double> dres, dres0;

    for (int x = 0; x < nobs; ++x) {
        bsptr->search(x, num_neighbors, &ires, &dres);

        // Casting the sizes so that both sides of the comparison are signed,
        // which avoids sign-compare warnings inside the assertion macro.
        EXPECT_EQ(static_cast<int>(ires.size()), num_neighbors);
        EXPECT_EQ(static_cast<int>(dres.size()), num_neighbors);
        for (const auto& ix : ires) { // self is not in there.
            EXPECT_TRUE(ix != x);
        }

        // EXPECT_* failures do not abort the test body, so we only inspect
        // the ends of the distance vector when it actually has elements;
        // calling back()/front() on an empty vector is undefined behavior.
        if (!dres.empty()) {
            EXPECT_EQ(dres.back(), 0);
            EXPECT_EQ(dres.front(), 0);
        }

        // Same behavior when one of the output pointers is null, i.e., when
        // only the distances or only the indices are requested.
        bsptr->search(x, num_neighbors, nullptr, &dres0);
        EXPECT_EQ(dres, dres0);
        bsptr->search(x, num_neighbors, &ires0, nullptr);
        EXPECT_EQ(ires, ires0);
    }
}

0 comments on commit 2b153f8

Please sign in to comment.