Skip to content

Commit

Permalink
Optimized and fixed component split
Browse files Browse the repository at this point in the history
  • Loading branch information
VPetukhov committed Oct 23, 2023
1 parent f415663 commit b0c80fa
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 34 deletions.
95 changes: 63 additions & 32 deletions src/processing/bmm_algorithm/bmm_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,59 +284,90 @@ function get_global_adjacent_classes(data::BmmData)::Dict{Int, Vector{Int}}
return adj_classes_global
end

function build_cell_graph(assignment::Vector{Int}, adjacent_points::Array{Vector{Int}, 1}, mol_ids::Vector{Int}, cell_id::Int; confidence::Union{Vector{Float64}, Nothing}=nothing, confidence_threshold::Float64=0.5)
cur_adj_point = Vector{Int}[]
if confidence === nothing
cur_adj_point = filter.(p -> assignment[p] == cell_id, adjacent_points[mol_ids])
else
mol_ids = mol_ids[confidence[mol_ids] .> confidence_threshold]
cur_adj_point = filter.(p -> (assignment[p] == cell_id) & (confidence[p] .> confidence_threshold), view(adjacent_points, mol_ids))
function split_nonempty_ids(array::AbstractVector{Int})
counts = count_array(array)
splitted = [Vector{Int}(undef, c) for c in counts]
last_id = zeros(Int, maximum(array))

for i in eachindex(array)
fac = array[i]
li = (last_id[fac] += 1)
splitted[fac][li] = i
end

vert_id_per_mol_id = Dict(mi => vi for (vi, mi) in enumerate(mol_ids))
return filter(x -> !isempty(x), splitted)
end

for points in cur_adj_point
for j in 1:length(points)
points[j] = vert_id_per_mol_id[points[j]]
function connected_components(
mol_ids::Vector{Int}, adj_ids::Vector{Vector{Int}}, assignment::Vector{Int}, vert_id_per_mol_id::Vector{Int}
)
# `adj_ids` has to be undirected here!
(length(mol_ids) > 1) || return ones(Int, length(mol_ids))

cell_id = assignment[mol_ids[1]]
labels = zeros(Int, length(mol_ids))

queue = Int[]
for u in mol_ids
ui = vert_id_per_mol_id[u]
(labels[ui] == 0) || continue
labels[ui] = ui
empty!(queue)
push!(queue, u)

while !isempty(queue)
src = popfirst!(queue)
for vertex in adj_ids[src]
(assignment[vertex] == cell_id) || continue

vi = vert_id_per_mol_id[vertex]
if labels[vi] == 0
push!(queue, vertex)
labels[vi] = ui
end
end
end
end

neg = Graphs.SimpleGraphs.cleanupedges!(cur_adj_point)
return Graphs.SimpleGraph(neg, cur_adj_point), mol_ids
return labels
end

function get_connected_components_per_label(assignment::Vector{Int}, adjacent_points::Array{Vector{Int}, 1}, min_molecules_per_cell::Int; kwargs...)
mol_ids_per_cell = split(1:length(assignment), assignment; drop_zero=true)
real_cell_ids = findall(length.(mol_ids_per_cell) .>= min_molecules_per_cell)
graph_per_cell = map(real_cell_ids) do ci
@spawn build_cell_graph(assignment, adjacent_points, mol_ids_per_cell[ci], ci; kwargs...)[1]
function get_connected_components_per_label(assignment::Vector{Int}, adj_ids::Vector{Vector{Int}})
mol_ids_per_cell = split_ids(assignment; drop_zero=true);
vert_id_per_mol_id = Vector{Int}(undef, length(assignment))

for mol_ids in mol_ids_per_cell
for (i,a) in enumerate(mol_ids)
vert_id_per_mol_id[a] = i
end
end

graph_per_cell = fetch.(graph_per_cell)
conn_comps = map(graph_per_cell) do g
@spawn Graphs.connected_components(g)
cc_per_cell = map(mol_ids_per_cell) do mids
!isempty(mids) || return Vector{Int}[]

@spawn split_nonempty_ids(
connected_components(mids, adj_ids, assignment, vert_id_per_mol_id)
)
end

conn_comps = fetch.(conn_comps)
cc_per_cell = fetch.(cc_per_cell)

return conn_comps, real_cell_ids, mol_ids_per_cell
return cc_per_cell, mol_ids_per_cell
end

function split_cells_by_connected_components!(data::BmmData; add_new_components::Bool, min_molecules_per_cell::Int)
conn_comps_per_cell, real_cell_ids, mol_ids_per_cell = get_connected_components_per_label(
data.assignment, data.adj_list.ids, min_molecules_per_cell; confidence=data.confidence
)
!isempty(real_cell_ids) || return
cc_per_cell, mol_ids_per_cell = get_connected_components_per_label(data.assignment, data.adj_list.ids)
!isempty(cc_per_cell) || return

new_comp_id = 0
if add_new_components
new_comp_id = length(data.components)
n_comps_per_cell = length.(conn_comps_per_cell)
n_comps_per_cell = length.(cc_per_cell)
n_new_comps = sum(n_comps_per_cell) - sum(n_comps_per_cell .> 0)
append_empty_components!(data, n_new_comps)
end

for (cell_id, conn_comps) in zip(real_cell_ids, conn_comps_per_cell)
for (cell_id, conn_comps) in enumerate(cc_per_cell)
(length(conn_comps) > 1) || continue

largest_cc_id = findmax(length.(conn_comps))[2]
Expand All @@ -354,10 +385,10 @@ function split_cells_by_connected_components!(data::BmmData; add_new_components:
end
end

function bmm!(data::BmmData; min_molecules_per_cell::Int, n_iters::Int=500,
function bmm!(data::BmmData; min_molecules_per_cell::Int=2, n_iters::Int=500,
new_component_frac::Float64=0.3, new_component_weight::Float64=0.2,
assignment_history_depth::Int=0, verbose::Union{Progress, Bool}=true,
component_split_step::Int=5, refine::Bool=true)
component_split_step::Int=3, refine::Bool=true)
progress = isa(verbose, Progress) ? verbose : (verbose ? Progress(n_iters) : nothing)

if (assignment_history_depth > 0) && !(:assignment_history in keys(data.tracer))
Expand All @@ -378,7 +409,7 @@ function bmm!(data::BmmData; min_molecules_per_cell::Int, n_iters::Int=500,

if (i % component_split_step == 0) || (i == n_iters)
# TODO: Slowest place
split_cells_by_connected_components!(data; add_new_components=(new_component_frac > 1e-10), min_molecules_per_cell=(i == n_iters ? 0 : min_molecules_per_cell))
split_cells_by_connected_components!(data; add_new_components=(new_component_frac > 1e-10))
end

drop_unused_components!(data)
Expand Down
11 changes: 9 additions & 2 deletions src/utils/general.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
count_array(values::VT where VT<: AbstractVector{<:Integer}, args...; max_value::Union{<:Integer, Nothing}=nothing, kwargs...) =
count_array!(zeros(Int, max_value !== nothing ? max_value : maximum(values)), values, args...; erase_counts=false, kwargs...)
count_array!(
zeros(Int, max_value !== nothing ? max_value : (isempty(values) ? 0 : maximum(values))),
values, args...;
erase_counts=false, kwargs...
)

function count_array!(counts::VT1 where VT1 <: AbstractVector{<:Integer}, values::VT2 where VT2 <: AbstractVector{<:Integer}; drop_zero::Bool=false, erase_counts::Bool=true)
function count_array!(
counts::VT1 where VT1 <: AbstractVector{<:Integer}, values::VT2 where VT2 <: AbstractVector{<:Integer};
drop_zero::Bool=false, erase_counts::Bool=true
)
if erase_counts
counts .= 0
end
Expand Down

0 comments on commit b0c80fa

Please sign in to comment.