Skip to content

Commit

Permalink
Enhance/skmeans convergence (#285)
Browse files Browse the repository at this point in the history
* Normalize input data to unit sphere

* Set assigned embedding cells to `1` before computing the means, otherwise convergence is hindered.
  • Loading branch information
weefuzzy authored Oct 31, 2024
1 parent bf9038b commit c76186c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions include/algorithms/public/SKMeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class SKMeans : public KMeans
using namespace Eigen;
using namespace _impl;
assert(!mTrained || (dataset.pointSize() == mDims && mK == k));
MatrixXd dataPoints = asEigen<Matrix>(dataset.getData());
MatrixXd dataPoints =
asEigen<Matrix>(dataset.getData()).colwise().normalized();
MatrixXd dataPointsT = dataPoints.transpose();
if (mTrained) { mAssignments = assignClusters(dataPointsT);}
else
Expand Down Expand Up @@ -87,9 +88,8 @@ class SKMeans : public KMeans
{
for (index i = 0; i < mAssignments.cols(); i++)
{
double val = mEmbedding(mAssignments(i), i);
mEmbedding.col(i).setZero();
mEmbedding(mAssignments(i), i) = val;
mEmbedding(mAssignments(i), i) = 1.0;
}
}

Expand Down

0 comments on commit c76186c

Please sign in to comment.