Skip to content

Commit

Permalink
split clusters occuring over large site spans
Browse files Browse the repository at this point in the history
  • Loading branch information
Alan Liddell committed Apr 17, 2019
1 parent d38d875 commit 68ffb3e
Showing 1 changed file with 96 additions and 62 deletions.
158 changes: 96 additions & 62 deletions +jrclust/+sort/assignClusters.m
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,33 @@
%ASSIGNCLUSTERS Given rho-delta information, assign spikes to clusters
sRes = computeCenters(dRes, sRes, hCfg);
sRes.spikeClusters = [];
sRes = doAssignClusters(dRes, sRes, hCfg);

% do initial assignment
hCfg.updateLog('assignClusters', sprintf('Assigning clusters (nClusters: %d)', numel(sRes.clusterCenters)), 1, 0);
sRes = doAssign(dRes, sRes, hCfg);
nClusters = numel(sRes.clusterCenters);

% split clusters by site distance
sRes = splitDist(dRes, sRes, hCfg);

% reassign spikes if we had to split
if nClusters ~= numel(sRes.clusterCenters)
sRes.spikeClusters = [];
sRes = doAssign(dRes, sRes, hCfg);
nClusters = numel(sRes.clusterCenters);
end

% remove small clusters
for iRepeat = 1:1000
sRes = removeSmall(dRes, sRes, hCfg);
if nClusters == numel(sRes.clusterCenters)
break;
end
sRes = doAssign(dRes, sRes, hCfg);
nClusters = numel(sRes.clusterCenters);
end

hCfg.updateLog('assignClusters', sprintf('Finished initial assignment (%d clusters)', nClusters), 0, 1);
end

%% LOCAL FUNCTIONS
Expand Down Expand Up @@ -32,80 +58,88 @@
end
end

function sRes = doAssignClusters(dRes, sRes, hCfg)
nRepeatMax = 1000;
function sRes = doAssign(dRes, sRes, hCfg)
%DOASSIGN Assign spikes to clusters
nSpikes = numel(sRes.ordRho);
nClusters = numel(sRes.clusterCenters);

if isempty(sRes.spikeClusters)
nClustersPrev = 0;
sRes.spikeClusters = zeros([nSpikes, 1], 'int32');
sRes.spikeClusters(sRes.clusterCenters) = 1:nClusters;
end

% one or no center, assign all spikes to one cluster
if numel(sRes.clusterCenters) == 0 || numel(sRes.clusterCenters) == 1
sRes.spikeClusters = ones([nSpikes, 1], 'int32');
sRes.clusterCenters = sRes.ordRho(1);
else
nClustersPrev = sRes.nClusters;
unassigned = sRes.spikeClusters <= 0;
canAssign = sRes.spikeClusters(sRes.spikeNeigh) > 0;
doAssign = unassigned & canAssign;

while any(doAssign)
hCfg.updateLog('assignIter', sprintf('%d/%d spikes unassigned, %d can be assigned', ...
sum(unassigned), nSpikes, sum(doAssign)), 0, 0);
sRes.spikeClusters(doAssign) = sRes.spikeClusters(sRes.spikeNeigh(doAssign));

unassigned = sRes.spikeClusters <= 0;
canAssign = sRes.spikeClusters(sRes.spikeNeigh) > 0;
doAssign = unassigned & canAssign;
end
end

removedClusters = 0;
hCfg.updateLog('assignClusters', sprintf('Assigning clusters (nClusters: %d)', numel(sRes.clusterCenters)), 1, 0);
nClusters = numel(sRes.clusterCenters);

% assign spikes to clusters
for iRepeat = 1:nRepeatMax % repeat 1000 times max
nSpikes = numel(sRes.ordRho);
nClusters = numel(sRes.clusterCenters);
% count spikes in clusters
sRes.spikesByCluster = arrayfun(@(iC) find(sRes.spikeClusters == iC), 1:nClusters, 'UniformOutput', 0);
sRes.unitCount = cellfun(@numel, sRes.spikesByCluster);
sRes.clusterSites = double(arrayfun(@(iC) mode(dRes.spikeSites(sRes.spikesByCluster{iC})), 1:nClusters));
end

if isempty(sRes.spikeClusters)
sRes.spikeClusters = zeros([nSpikes, 1], 'int32');
sRes.spikeClusters(sRes.clusterCenters) = 1:nClusters;
function sRes = splitDist(dRes, sRes, hCfg)
nClusters = numel(sRes.clusterCenters);
for jCluster = 1:nClusters
% get the number of unique sites for this cluster
jSpikes = sRes.spikesByCluster{jCluster};
jSites = dRes.spikeSites(jSpikes);
uniqueSites = unique(jSites);
if numel(unique(jSites)) == 1
continue;
end

% one or no center, assign all spikes to one cluster
if numel(sRes.clusterCenters) == 0 || numel(sRes.clusterCenters) == 1
sRes.spikeClusters = ones([nSpikes, 1], 'int32');
sRes.clusterCenters = sRes.ordRho(1);
else
nNeigh = sRes.spikeNeigh(sRes.ordRho);

for i = 1:10
unassigned = find(sRes.spikeClusters(sRes.ordRho) <= 0);
if isempty(unassigned)
break;
end

unassigned = unassigned(:)';

for j = unassigned
sRes.spikeClusters(sRes.ordRho(j)) = sRes.spikeClusters(nNeigh(j));
end
nUnassigned = sum(sRes.spikeClusters <= 0);

if nUnassigned == 0
break;
end

hCfg.updateLog('assignIter', sprintf('iter %d, %d clusters remain unassigned', i, nUnassigned), 0, 0);
end
sRes.spikeClusters(sRes.spikeClusters <= 0) = 1; %background
end
% order spike sites by density, descending
jRho = sRes.spikeRho(jSpikes);
[~, ordering] = sort(jRho, 'descend');
jSpikes = jSpikes(ordering);
jSites = jSites(ordering);

hCfg.minClusterSize = max(hCfg.minClusterSize, 2*size(dRes.spikeFeatures, 1));
siteOrdering = arrayfun(@(k) find(jSites == k, 1), uniqueSites);
[~, siteOrdering] = sort(siteOrdering);
uniqueSites = uniqueSites(siteOrdering);

% count spikes in clusters
sRes.spikesByCluster = arrayfun(@(iC) find(sRes.spikeClusters == iC), 1:nClusters, 'UniformOutput', 0);
sRes.unitCount = cellfun(@numel, sRes.spikesByCluster);
sRes.clusterSites = double(arrayfun(@(iC) mode(dRes.spikeSites(sRes.spikesByCluster{iC})), 1:nClusters));
siteLocs = hCfg.siteLoc(uniqueSites, :);
siteDists = pdist2(siteLocs, siteLocs);

% remove small clusters
smallClusters = find(sRes.unitCount <= hCfg.minClusterSize);
if isempty(smallClusters) % done!
break;
% find pairwise distances which exceed the merge radius
isFar = siteDists > 2*hCfg.evtMergeRad;
[r, c] = find(isFar);
if isempty(r)
continue;
end

% still here? try again
sRes.clusterCenters(smallClusters) = [];
sRes.spikeClusters = [];
removedClusters = removedClusters + numel(smallClusters);
splitOff = jSpikes(arrayfun(@(k) find(jSites == k, 1), unique(uniqueSites(r(r > c)))));
sRes.clusterCenters = [sRes.clusterCenters; splitOff];
end
end

if iRepeat == nRepeatMax
warning('assignClusters: exceeded nRepeatMax = %d\n', nRepeatMax);
end
end % for
function sRes = removeSmall(dRes, sRes, hCfg)
hCfg.minClusterSize = max(hCfg.minClusterSize, 2*size(dRes.spikeFeatures, 1));

% remove small clusters
smallClusters = sRes.unitCount <= hCfg.minClusterSize;

hCfg.updateLog('assignClusters', ...
sprintf('Finished assigning clusters (was %d, now %d: %d clusters with fewer than %d spikes removed)', ...
nClustersPrev, nClusters, removedClusters, hCfg.minClusterSize), 0, 1);
if any(smallClusters)
sRes.clusterCenters(smallClusters) = [];
sRes.spikeClusters = [];
end
end

0 comments on commit 68ffb3e

Please sign in to comment.