From c76186c386a7da10a70151280604a36d23532557 Mon Sep 17 00:00:00 2001 From: Owen Green Date: Thu, 31 Oct 2024 09:26:50 +0000 Subject: [PATCH] Enhance/skmeans convergence (#285) * Normalize input data to unit sphere * Set assigned embedding cells to `1` before computing the means, otherwise convergence is hindered. --- include/algorithms/public/SKMeans.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/algorithms/public/SKMeans.hpp b/include/algorithms/public/SKMeans.hpp index 0ddf9699d..0fe41640a 100644 --- a/include/algorithms/public/SKMeans.hpp +++ b/include/algorithms/public/SKMeans.hpp @@ -33,7 +33,8 @@ class SKMeans : public KMeans using namespace Eigen; using namespace _impl; assert(!mTrained || (dataset.pointSize() == mDims && mK == k)); - MatrixXd dataPoints = asEigen(dataset.getData()); + MatrixXd dataPoints = + asEigen(dataset.getData()).colwise().normalized(); MatrixXd dataPointsT = dataPoints.transpose(); if (mTrained) { mAssignments = assignClusters(dataPointsT);} else @@ -87,9 +88,8 @@ class SKMeans : public KMeans { for (index i = 0; i < mAssignments.cols(); i++) { - double val = mEmbedding(mAssignments(i), i); mEmbedding.col(i).setZero(); - mEmbedding(mAssignments(i), i) = val; + mEmbedding(mAssignments(i), i) = 1.0; } }