Skip to content

Commit

Permalink
Removed Dirichlet sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Viktor Petukhov committed Jun 4, 2024
1 parent 0adc489 commit eb148b9
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 204 deletions.
2 changes: 0 additions & 2 deletions src/processing/Processing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ include("distributions/MvNormal.jl")
include("distributions/CategoricalSmoothed.jl")

include("models/AdjList.jl")
include("models/DistributionSampler.jl")
include("models/InitialParams.jl")
include("models/Component.jl")
include("models/BmmData.jl")
Expand All @@ -44,7 +43,6 @@ include("data_processing/boundary_estimation.jl")
include("bmm_algorithm/molecule_clustering.jl")
include("bmm_algorithm/compartment_segmentation.jl")
include("bmm_algorithm/tracing.jl")
include("bmm_algorithm/distribution_samplers.jl")
include("bmm_algorithm/history_analysis.jl")
include("bmm_algorithm/bmm_algorithm.jl")

Expand Down
108 changes: 18 additions & 90 deletions src/processing/bmm_algorithm/bmm_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function aggregate_adjacent_component_weights!(
empty!(comp_ids)
bg_comp_weight = 0.0

@inbounds @simd for i in 1:length(adjacent_weights)
for i in 1:length(adjacent_weights)
c_point = adjacent_points[i]
c_id = assignment[c_point]
cw = adjacent_weights[i]
Expand All @@ -67,20 +67,14 @@ end

@inline function fill_adjacent_component_weights!(
adj_classes::Vector{Int}, adj_weights::Vector{Float64}, data::BmmData, mol_id::Int;
component_weights::Dict{Int, Float64}, adj_classes_global::Dict{Int, Vector{Int}}
component_weights::Dict{Int, Float64}
)
# It looks like this cannot be optimized further, even with vectorization, so creating a vectorized version of expect_dirichlet_spatial makes little sense
bg_comp_weight = aggregate_adjacent_component_weights!(
adj_classes, adj_weights, component_weights, data.assignment,
data.adj_list.ids[mol_id], data.adj_list.weights[mol_id]
)

if mol_id in keys(adj_classes_global)
n1 = length(adj_classes)
append!(adj_classes, adj_classes_global[mol_id])
append!(adj_weights, ones(length(adj_classes) - n1) .* data.real_edge_weight)
end

return bg_comp_weight
end

Expand Down Expand Up @@ -162,13 +156,12 @@ function expect_dirichlet_spatial!(data::BmmData; stochastic::Bool=true)
adj_weights = [Float64[] for _ in 1:n_threads];
denses = [Float64[] for _ in 1:n_threads]

adj_classes_global = get_global_adjacent_classes(data)
assignment = deepcopy(data.assignment)

@threads for i in 1:size(data.x, 1)
ti = Threads.threadid()
bg_comp_weight = fill_adjacent_component_weights!(
adj_classes[ti], adj_weights[ti], data, i; component_weights=component_weights[ti], adj_classes_global
adj_classes[ti], adj_weights[ti], data, i; component_weights=component_weights[ti]
)

expect_density_for_molecule!(denses[ti], data, i; adj_classes=adj_classes[ti], adj_weights=adj_weights[ti], bg_comp_weight)
Expand All @@ -181,16 +174,21 @@ function expect_dirichlet_spatial!(data::BmmData; stochastic::Bool=true)
end
end

function update_prior_probabilities!(components::Array{<:Component, 1}, new_component_weight::Float64)
c_weights = [max(c.n_samples, new_component_weight) for c in components]
prior_probs = rand(Distributions.Dirichlet(c_weights))
function update_prior_probabilities!(components::Vector{<:Component})
# c_weights = [log(max(c.n_samples, new_component_weight) + 1) for c in components]
# t_mmc = 150.0; # TODO: finish the idea
# t_dist = Normal(t_mmc, t_mmc * 3)
# prior_probs = pdf.(t_dist, c_weights)
# prior_probs[c_weights .< t_mmc] .= c_weights[c_weights .< t_mmc] ./ t_mmc .* pdf(t_dist, t_mmc)

for (c, p) in zip(components, prior_probs)
c.prior_probability = p
# t_mmc = (log10(100) + log10(1000)) / 2
# # t_mmc = log10(2 + 1)
# # t_mmc = log10(100 + 1)
# t_dist = Normal(t_mmc, t_mmc / 2)
# prior_probs = [pdf(t_dist, log10(c.n_samples + 1)) for c in components]

for c in components
c.prior_probability = c.n_samples
end
end

Expand Down Expand Up @@ -242,58 +240,8 @@ function noise_composition_density(data::BmmData{T, MvNormalF{M, N}} where {T, M
return 1e-2
end

"""
    noise_position_density(data::BmmData)

Return the density of a zero-mean multivariate normal with diagonal covariance
`std_values .^ 2`, evaluated at the point `3 .* std_values`. `std_values` is read
from `data.distribution_sampler.shape_prior`.
"""
function noise_position_density(data::BmmData)::Float64
    sigmas = data.distribution_sampler.shape_prior.std_values
    # Diagonal covariance built from the per-axis standard deviations.
    covariance = diagm(0 => sigmas.^2)
    noise_dist = MultivariateNormal(zeros(length(sigmas)), covariance)
    # The density is taken three standard deviations away from the mean on every axis.
    return pdf(noise_dist, 3 .* sigmas)
end

estimate_noise_density_level(data::BmmData) =
noise_position_density(data) * noise_composition_density(data)

"""
    append_empty_component!(data::BmmData)

Push the result of `sample_distribution!(data)` onto `data.components` and return
the last element of `data.components`.
"""
function append_empty_component!(data::BmmData)
    push!(data.components, sample_distribution!(data))
    return data.components[end]
end

"""
    append_empty_components!(data::BmmData, new_component_frac::Float64)

Convert `new_component_frac` into a component count, rounding
`new_component_frac * length(data.components)` to an `Int`, and forward to the
`Int` method of `append_empty_components!`.
"""
function append_empty_components!(data::BmmData, new_component_frac::Float64)
    n_new = round(Int, new_component_frac * length(data.components))
    return append_empty_components!(data, n_new)
end

"""
    append_empty_components!(data::BmmData, n::Int)

Append `n` components to `data.components`. Do nothing when `n` is not positive.
A single component goes through `append_empty_component!`; larger counts are
obtained in one call to `sample_distributions!(data, n)`. Always returns `nothing`.
"""
function append_empty_components!(data::BmmData, n::Int)
    if n <= 0
        return
    end

    if n == 1
        append_empty_component!(data)
    else
        append!(data.components, sample_distributions!(data, n))
    end

    return
end

"""
    get_global_adjacent_classes(data::BmmData)

Build a dictionary from molecule index to the indices of components registered
for that molecule. Only components whose `n_samples` is not positive are
considered. For each such component, the first index returned by a 1-nearest-neighbour
query of `data.position_knn_tree` at `comp.position_params.μ` is looked up, and the
component index is recorded for every entry of `data.adj_list.ids` at that index
and then for the index itself.
"""
function get_global_adjacent_classes(data::BmmData)::Dict{Int, Vector{Int}}
    classes_per_mol = Dict{Int, Vector{Int}}()

    for (comp_id, comp) in enumerate(data.components)
        # Components that already own samples are skipped.
        (comp.n_samples > 0) && continue

        nearest_mol = knn(data.position_knn_tree, comp.position_params.μ, 1)[1][1]

        # Register the component for all listed neighbours first, then for the
        # nearest index itself (same insertion order as before).
        for mol_id in data.adj_list.ids[nearest_mol]
            push!(get!(() -> Int[], classes_per_mol, mol_id), comp_id)
        end
        push!(get!(() -> Int[], classes_per_mol, nearest_mol), comp_id)
    end

    return classes_per_mol
end
data.noise_position_density * noise_composition_density(data)

function split_nonempty_ids(array::AbstractVector{Int})
counts = count_array(array)
Expand Down Expand Up @@ -366,18 +314,10 @@ function get_connected_components_per_label(assignment::Vector{Int}, adj_ids::Ve
return cc_per_cell, mol_ids_per_cell
end

function split_cells_by_connected_components!(data::BmmData; add_new_components::Bool)
function split_cells_by_connected_components!(data::BmmData)
cc_per_cell, mol_ids_per_cell = get_connected_components_per_label(data.assignment, data.adj_list.ids)
!isempty(cc_per_cell) || return

new_comp_id = 0
if add_new_components
new_comp_id = length(data.components)
n_comps_per_cell = length.(cc_per_cell)
n_new_comps = sum(n_comps_per_cell) - sum(n_comps_per_cell .> 0)
append_empty_components!(data, n_new_comps)
end

for (cell_id, conn_comps) in enumerate(cc_per_cell)
(length(conn_comps) > 1) || continue

Expand All @@ -386,28 +326,18 @@ function split_cells_by_connected_components!(data::BmmData; add_new_components:
(ci != largest_cc_id) || continue

mol_ids = mol_ids_per_cell[cell_id][c_ids]

if add_new_components
new_comp_id += 1
end

data.assignment[mol_ids] .= new_comp_id
data.assignment[mol_ids] .= 0
end
end
end

function bmm!(
data::BmmData; min_molecules_per_cell::Int=2, n_iters::Int=500,
new_component_frac::Float64=0.3, new_component_weight::Float64=0.2,
assignment_history_depth::Int=0, verbose::Union{Progress, Bool}=true,
component_split_step::Int=3, refine::Bool=true,
freeze_composition::Bool=false, freeze_position::Bool=false, freeze_components::Bool=false
)

if freeze_components
new_component_frac = 0.0
end

progress = isa(verbose, Progress) ? verbose : (verbose ? Progress(n_iters) : nothing)

if (assignment_history_depth > 0) && !(:assignment_history in keys(data.tracer))
Expand All @@ -419,16 +349,14 @@ function bmm!(
maximize!(data; freeze_composition, freeze_position)

for i in 1:n_iters
# TODO: second slowest place
append_empty_components!(data, new_component_frac)
update_prior_probabilities!(data.components, new_component_weight)
update_prior_probabilities!(data.components)
update_n_mols_per_segment!(data)

expect_dirichlet_spatial!(data)

if (i % component_split_step == 0) || (i == n_iters)
# TODO: Slowest place
split_cells_by_connected_components!(data; add_new_components=(new_component_frac > 1e-10))
split_cells_by_connected_components!(data)
end

if !freeze_components
Expand Down
63 changes: 0 additions & 63 deletions src/processing/bmm_algorithm/distribution_samplers.jl

This file was deleted.

5 changes: 4 additions & 1 deletion src/processing/data_processing/boundary_estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,10 @@ function boundary_polygons(pos_data::Matrix{Float64}, cell_labels::Vector{<:Inte
return Dict(cid => cb for (cid,cb) in zip(cell_names, cell_borders) if !isempty(cb))
end

function boundary_polygons_auto(pos_data::Matrix{Float64}, assignment::Vector{<:Integer}; estimate_per_z::Bool, cell_names::Union{Vector{String}, Nothing}=nothing)
function boundary_polygons_auto(
pos_data::Matrix{Float64}, assignment::Vector{<:Integer};
estimate_per_z::Bool, cell_names::Union{Vector{String}, Nothing}=nothing, verbose::Bool=true
)
verbose && @info "Estimating boundary polygons"

poly_joint = boundary_polygons(pos_data, assignment; cell_names);
Expand Down
12 changes: 7 additions & 5 deletions src/processing/data_processing/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,14 @@ function initialize_bmm_data(
size_prior = initialize_shape_prior(Float64(scale), scale_std; min_molecules_per_cell, is_3d=(:z in propertynames(df_spatial)))
init_params = cell_centers_uniformly(df_spatial, n_cells_init; scale=scale)

components, sampler, assignment = initial_distributions(df_spatial, init_params; size_prior=size_prior)
components, assignment = initial_distributions(df_spatial, init_params; size_prior=size_prior)

std_vals = size_prior.std_values
noise_position_density = pdf(MultivariateNormal(zeros(length(std_vals)), diagm(0 => std_vals.^2)), 3 .* std_vals)

return BmmData(
components, df_spatial, adj_list, assignment, sampler;
prior_seg_confidence, kwargs...
components, df_spatial, adj_list, assignment;
prior_seg_confidence, noise_position_density, kwargs...
)
end

Expand All @@ -163,9 +166,8 @@ function initial_distributions(df_spatial::DataFrame, initial_params::InitialPar
components = [Component(pd, gd, shape_prior=deepcopy(size_prior)) for (pd, gd) in zip(position_distrubutions, gene_distributions)]

gene_sampler = CategoricalSmoothed(ones(Float64, gene_num))
sampler = DistributionSampler(deepcopy(size_prior), gene_smooth)

return components, sampler, initial_params.assignment
return components, initial_params.assignment
end

function initialize_position_params_from_assignment(pos_data::Matrix{Float64}, cell_assignment::Vector{Int})
Expand Down
12 changes: 5 additions & 7 deletions src/processing/models/BmmData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ mutable struct BmmData{L, CT}

# Distribution-related
components::Array{Component{L, CT}, 1};
distribution_sampler::DistributionSampler{L};
assignment::Vector{Int};
max_component_guid::Int;

noise_position_density::Float64;
noise_density::Float64;

cluster_per_cell::Vector{Int};
Expand All @@ -40,7 +40,6 @@ mutable struct BmmData{L, CT}
# Utils
tracer::Dict{Symbol, Any}
misc::Dict{Symbol, Any}
center_sample_cache::Vector{Int}

# Parameters
prior_seg_confidence::Float64
Expand All @@ -57,11 +56,10 @@ mutable struct BmmData{L, CT}
- `adj_list::AdjList`:
- `adjacent_weights::Array{Array{Float64, 1}, 1}`: edge weights, used for smoothness penalty
- `real_edge_weight::Float64`: weight of an edge for "average" real point
- `distribution_sampler::DistributionSampler`:
- `assignment::Array{Int, 1}`:
"""
function BmmData(components::Array{Component{N, CT}, 1}, x::DataFrame, adj_list::AdjList,
assignment::Vector{Int}, distribution_sampler::DistributionSampler{N}; real_edge_weight::Float64=1.0, k_neighbors::Int=20,
function BmmData(components::Array{Component{N, CT}, 1}, x::DataFrame, adj_list::AdjList, assignment::Vector{Int};
real_edge_weight::Float64=1.0, k_neighbors::Int=20, noise_position_density::Float64=0.0,
cluster_penalty_mult::Float64=0.25, use_gene_smoothing::Bool=true, prior_seg_confidence::Float64=0.5,
min_nuclei_frac::Float64=0.1, mrf_strength::Float64=0.1, na_genes::Vector{Int}=Int[]) where {N, CT}
@assert maximum(assignment) <= length(components)
Expand Down Expand Up @@ -92,10 +90,10 @@ mutable struct BmmData{L, CT}
self = new{N, CT}(
x, p_data, comp_data, confidence(x), cluster_per_molecule, Int[], nuclei_probs,
adj_list, real_edge_weight, position_knn_tree, knn_neighbors,
components, distribution_sampler, assignment, length(components), 0.0,
components, assignment, length(components), noise_position_density, 0.0,
Int[],
Int[], Int[], # prior segmentation info
Dict{Symbol, Any}(), Dict{Symbol, Any}(), Int[],
Dict{Symbol, Any}(), Dict{Symbol, Any}(),
prior_seg_confidence, cluster_penalty_mult, use_gene_smoothing, min_nuclei_frac, mrf_strength
)

Expand Down
Loading

0 comments on commit eb148b9

Please sign in to comment.