From 8b781eba4aeb30f06a0fce10df0d0b29b21e33c3 Mon Sep 17 00:00:00 2001 From: Owen Green Date: Thu, 31 Oct 2024 11:31:48 +0000 Subject: [PATCH] SKMeans: Add option for 'Forgy' initialization --- include/algorithms/public/SKMeans.hpp | 50 ++++++++++++++++++++------- include/clients/nrt/SKMeansClient.hpp | 12 ++++--- 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/include/algorithms/public/SKMeans.hpp b/include/algorithms/public/SKMeans.hpp index 0fe41640a..256b6bbd4 100644 --- a/include/algorithms/public/SKMeans.hpp +++ b/include/algorithms/public/SKMeans.hpp @@ -18,6 +18,7 @@ under the European Union’s Horizon 2020 research and innovation programme #include "../../data/TensorTypes.hpp" #include #include +#include #include namespace fluid { @@ -27,8 +28,16 @@ class SKMeans : public KMeans { public: + + enum Initializer { + // Random partition assigns points to random clusters at init + Random_Partition, + //'Forgy' initializes means with k random data points + Forgy + }; + void train(const FluidDataSet& dataset, index k, - index maxIter) + index maxIter, unsigned initialize ) { using namespace Eigen; using namespace _impl; @@ -41,14 +50,14 @@ class SKMeans : public KMeans { mK = k; mDims = dataset.pointSize(); - initMeans(dataPoints); + initMeans(dataPoints, initialize); } while (maxIter-- > 0) { mEmbedding = mMeans.matrix() * dataPointsT; auto assignments = assignClusters(mEmbedding); - if (!changed(assignments)) { break; } + if (mAssignments.rows() && !changed(assignments)) { break; } else mAssignments = assignments; updateEmbedding(); @@ -69,19 +78,34 @@ class SKMeans : public KMeans } private: - - void initMeans(Eigen::MatrixXd& dataPoints) + void initMeans(Eigen::MatrixXd& dataPoints, unsigned initializer) { using namespace Eigen; mMeans = ArrayXXd::Zero(mK, mDims); - mAssignments = - ((0.5 + (0.5 * ArrayXd::Random(dataPoints.rows()))) * (mK - 1)) - .round() - .cast(); - mEmbedding = MatrixXd::Zero(mK, dataPoints.rows()); - for (index i = 0; i < 
dataPoints.rows(); i++) - mEmbedding(mAssignments(i), i) = 1; - computeMeans(dataPoints); + + switch (initializer) + { + default: + case Initializer::Random_Partition: + mAssignments = + ((0.5 + (0.5 * ArrayXd::Random(dataPoints.rows()))) * (mK - 1)) + .round() + .cast(); + mEmbedding = MatrixXd::Zero(mK, dataPoints.rows()); + for (index i = 0; i < dataPoints.rows(); i++) + mEmbedding(mAssignments(i), i) = 1; + computeMeans(dataPoints); + break; + + case Initializer::Forgy: // means from random selection of data points + ArrayXidx dataIndices = + ArrayXidx::LinSpaced(dataPoints.rows(), 0, dataPoints.rows() - 1); + std::vector samples(mK); + std::sample(dataIndices.begin(), dataIndices.end(), samples.begin(), mK, + std::mt19937{std::random_device{}()}); + mMeans = dataPoints(samples, Eigen::all); + break; + } } void updateEmbedding() diff --git a/include/clients/nrt/SKMeansClient.hpp b/include/clients/nrt/SKMeansClient.hpp index fcefb3d88..4bd7aa005 100644 --- a/include/clients/nrt/SKMeansClient.hpp +++ b/include/clients/nrt/SKMeansClient.hpp @@ -20,13 +20,15 @@ namespace fluid { namespace client { namespace skmeans { -enum { kName, kNumClusters, kThreshold, kMaxIter }; +enum { kName, kNumClusters, kThreshold, kMaxIter, kInit }; constexpr auto SKMeansParams = defineParameters( StringParam>("name", "Name"), LongParam("numClusters", "Number of Clusters", 4, Min(0)), FloatParam("encodingThreshold", "Encoding Threshold", 0.25, Min(0), Max(1)), - LongParam("maxIter", "Max number of Iterations", 100, Min(1))); + LongParam("maxIter", "Max number of Iterations", 100, Min(1)), + EnumParam("initialize","Initialize method",0, "Random Assignment", "Sampled Means") + ); class SKMeansClient : public FluidBaseClient, OfflineIn, @@ -79,7 +81,7 @@ class SKMeansClient : public FluidBaseClient, if (dataSet.size() == 0) return Error(EmptyDataSet); if (k <= 1) return Error(SmallK); if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter); + 
mAlgorithm.train(dataSet, k, maxIter, get()); IndexVector assignments(dataSet.size()); mAlgorithm.getAssignments(assignments); return getCounts(assignments, k); @@ -100,7 +102,7 @@ class SKMeansClient : public FluidBaseClient, if (k <= 1) return Error(SmallK); if (maxIter <= 0) maxIter = 100; if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter); + mAlgorithm.train(dataSet, k, maxIter, get()); IndexVector assignments(dataSet.size()); mAlgorithm.getAssignments(assignments); StringVectorView ids = dataSet.getIds(); @@ -171,7 +173,7 @@ class SKMeansClient : public FluidBaseClient, if (k <= 1) return Error(SmallK); if (maxIter <= 0) maxIter = 100; if(mTracker.changed(k)) mAlgorithm.clear(); - mAlgorithm.train(dataSet, k, maxIter); + mAlgorithm.train(dataSet, k, maxIter,get()); IndexVector assignments(dataSet.size()); mAlgorithm.getAssignments(assignments); encode(srcClient, dstClient);