From 137f03a1511bbdab50b71bf32827470d476cebf4 Mon Sep 17 00:00:00 2001
From: Tim Holy
Date: Sun, 19 Nov 2023 07:48:48 -0600
Subject: [PATCH 1/5] Improve robustness

This eliminates the common failures observed in
https://github.com/dmetivie/ExpectationMaximization.jl/issues/11#issuecomment-1817926977
---
 src/fit_em.jl | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/fit_em.jl b/src/fit_em.jl
index 53aecb7..bf89f35 100644
--- a/src/fit_em.jl
+++ b/src/fit_em.jl
@@ -124,7 +124,15 @@ function E_step!(
 ) where {T<:AbstractFloat}
     # evaluate likelihood for each type k
     for k in eachindex(dists)
-        LL[:, k] .= log(α[k]) .+ logpdf.(dists[k], y)
+        logα = log(α[k])
+        robust && !isfinite(logα) && continue
+        distk = dists[k]
+        for n in eachindex(y)
+            logp = logpdf(distk, y[n])
+            if !robust || isfinite(logp)
+                LL[n, k] = logα + logp
+            end
+        end
     end
     robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
     # get posterior of each category
@@ -143,12 +151,16 @@ function E_step!(
 )
     # evaluate likelihood for each type k
     for k in eachindex(dists)
-        LL[:, k] .= log(α[k])
+        logα = log(α[k])
+        robust && !isfinite(logα) && continue
+        distk = dists[k]
         for n in axes(y, 2)
-            LL[n, k] += logpdf(dists[k], y[:, n])
+            logp = logpdf(distk, y[:, n])
+            if !robust || isfinite(logp)
+                LL[n, k] = logα + logp
+            end
         end
     end
-    robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
     # get posterior of each category
     c[:] = logsumexp(LL, dims = 2)
     γ[:, :] = exp.(LL .- c)

From e6988ebc9e96e6fa8b9810f0b94a0da3e2000f35 Mon Sep 17 00:00:00 2001
From: Tim Holy
Date: Tue, 21 Nov 2023 05:13:41 -0600
Subject: [PATCH 2/5] hoist the `robust` check

---
 src/fit_em.jl | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/src/fit_em.jl b/src/fit_em.jl
index bf89f35..c20dbb0 100644
--- a/src/fit_em.jl
+++ b/src/fit_em.jl
@@ -124,14 +124,18 @@ function E_step!(
 ) where {T<:AbstractFloat}
     # evaluate likelihood for each type k
     for k in eachindex(dists)
-        logα = log(α[k])
-        robust && !isfinite(logα) && continue
-        distk = dists[k]
-        for n in eachindex(y)
-            logp = logpdf(distk, y[n])
-            if !robust || isfinite(logp)
+        logα, distk = log(α[k]), dists[k]
+        if robust
+            isfinite(logα) || continue
+            for n in eachindex(y)
+                logp = logpdf(distk, y[n])
+                isfinite(logp) || continue
                 LL[n, k] = logα + logp
             end
+        else
+            for n in eachindex(y)
+                LL[n, k] = logα + logpdf(distk, y[n])
+            end
         end
     end
     robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
@@ -151,14 +155,18 @@ function E_step!(
 )
     # evaluate likelihood for each type k
     for k in eachindex(dists)
-        logα = log(α[k])
-        robust && !isfinite(logα) && continue
-        distk = dists[k]
-        for n in axes(y, 2)
-            logp = logpdf(distk, y[:, n])
-            if !robust || isfinite(logp)
+        logα, distk = log(α[k]), dists[k]
+        if robust
+            isfinite(logα) || continue
+            for n in axes(y, 2)
+                logp = logpdf(distk, y[:, n])
+                isfinite(logp) || continue
                 LL[n, k] = logα + logp
             end
+        else
+            for n in axes(y, 2)
+                LL[n, k] = logα + logpdf(distk, y[:, n])
+            end
         end
     end
     # get posterior of each category

From 55eca9ed3f3c783342b17088bd9f341c3ea86b08 Mon Sep 17 00:00:00 2001
From: Tim Holy
Date: Tue, 28 Nov 2023 19:25:54 -0600
Subject: [PATCH 3/5] Remove `replace`

---
 src/fit_em.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/fit_em.jl b/src/fit_em.jl
index c20dbb0..f26d0ba 100644
--- a/src/fit_em.jl
+++ b/src/fit_em.jl
@@ -138,7 +138,6 @@ function E_step!(
             end
         end
     end
-    robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
     # get posterior of each category
     logsumexp!(c, LL) # c[:] = logsumexp(LL, dims=2)
     γ[:, :] .= exp.(LL .- c)
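
Note: background on the E-step guards above (an illustration, not part of the patch series). Once a mixture weight α[k] collapses to zero, log(α[k]) is -Inf; that alone is harmless, since exp(-Inf) == 0 under logsumexp, but if logpdf also returns +Inf for some observation (e.g. a near-degenerate component evaluated at its own mean), the sum is NaN, and a single NaN entry corrupts logsumexp and the whole responsibility matrix γ. In base Julia:

    logα = log(0.0)   # -Inf: the log-weight of a dropped-out component
    logα + (-3.2)     # -Inf: fine, contributes exp(-Inf) == 0 to logsumexp
    logα + Inf        # NaN: one such entry poisons logsumexp(LL) and γ

Because the guarded loops never write a non-finite term into LL, the after-the-fact replace! clamp becomes redundant, which is what PATCH 3/5 removes.
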
From fec4d6b727e75ec69dde5184317ae6678710d913 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20M=C3=A9tivier?= <46794064+dmetivie@users.noreply.github.com>
Date: Thu, 28 Dec 2023 22:20:08 +0100
Subject: [PATCH 4/5] add dropout test

---
 test/runtests.jl | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/test/runtests.jl b/test/runtests.jl
index 6aaccd9..9544d98 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -260,3 +260,28 @@ end
     ẑ = predict(m, y)
     @test count(ẑ .== z) / N > 0.85
 end
+
+@testset "Test robustness against dropout issue" begin
+    # See https://github.com/dmetivie/ExpectationMaximization.jl/issues/11
+    # In this example, one of the mixture weights goes to zero, which at iteration 3 produces
+    # ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.
+    Random.seed!(1234)
+
+    N = 600
+
+    ctrue = [[-0.3, 1],
+        [-0.4, 0.7],
+        [0.4, -0.6]]
+    X = reduce(hcat, [randn(length(c), N÷3) .+ c for c in ctrue])
+    mix_bad_guess = MixtureModel([MvNormal([1.6, -2.4], [100 0.0; 0.0 1]), MvNormal([-1.1, -0.6], 0.01), MvNormal([0.4, 2.4], 1)])
+
+    fit_mle(mix_bad_guess, X, maxiter = 1)
+
+    try # make sure our test case is problematic after a few iterations without the robust option
+        fit_mle(mix_bad_guess, X, maxiter = 2) # triggers error
+    catch e
+        @test true
+    end
+    # no error thrown (however the EM converged to a bad local maximum)
+    mix_mle_bad = fit_mle(mix_bad_guess, X, maxiter = 2000, robust = true)
+end
\ No newline at end of file
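
Note: how dropout turns into the PosDefException quoted in the test above (an illustration, not part of the patch series; it reproduces the same exception type, not necessarily the exact path the EM takes here). When essentially all the responsibility for a component concentrates on a single point, the covariance MLE degenerates and its Cholesky factorization fails:

    using Distributions

    X1 = repeat([0.4, -0.6], 1, 5)  # five copies of one point: the MLE covariance is all zeros
    fit_mle(MvNormal, X1)           # expected to throw PosDefException (singular covariance)
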
From 7098aa4d4ed90bf85b0cf173f66f8a01c2ce0044 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20M=C3=A9tivier?= <46794064+dmetivie@users.noreply.github.com>
Date: Thu, 28 Dec 2023 22:35:46 +0100
Subject: [PATCH 5/5] robustify SEM against dropout + test

---
 src/stochastic_em.jl | 17 ++++++++++++++---
 test/runtests.jl     | 15 ++++++++++++---
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/src/stochastic_em.jl b/src/stochastic_em.jl
index 4cd4da4..c652c34 100644
--- a/src/stochastic_em.jl
+++ b/src/stochastic_em.jl
@@ -59,8 +59,13 @@ function fit_mle!(
     # M-step
     # using ẑ, maximize (update) the parameters
     α[:] = length.(cat)/N
-    dists[:] = [fit_mle(dists[k], y[cat[k]]) for k = 1:K]
-
+    dists[:] = map(1:K) do k
+        if α[k] > 0
+            fit_mle(dists[k], y[cat[k]])
+        else
+            dists[k]
+        end
+    end
     # E-step
     # evaluate likelihood for each type k
     E_step!(LL, c, γ, dists, α, y; robust = robust)
@@ -133,7 +138,13 @@ function fit_mle!(
     # M-step
     # using ẑ, maximize (update) the parameters
     α[:] = length.(cat)/N
-    dists[:] = [fit_mle(dists[k], y[:, cat[k]]) for k = 1:K]
+    dists[:] = map(1:K) do k
+        if α[k] > 0
+            fit_mle(dists[k], y[:, cat[k]])
+        else
+            dists[k]
+        end
+    end
 
     # E-step
     # evaluate likelihood for each type k
diff --git a/test/runtests.jl b/test/runtests.jl
index 9544d98..abcfc8e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -278,10 +278,19 @@ end
     fit_mle(mix_bad_guess, X, maxiter = 1)
 
     try # make sure our test case is problematic after a few iterations without the robust option
-        fit_mle(mix_bad_guess, X, maxiter = 2) # triggers error
+        fit_mle(mix_bad_guess, X, maxiter = 20) # triggers error
+        @test false
     catch e
         @test true
    end
-    # no error thrown (however the EM converged to a bad local maximum)
-    mix_mle_bad = fit_mle(mix_bad_guess, X, maxiter = 2000, robust = true)
+    begin
+        #! no error thrown, however the EM converges to a bad local maximum!
+        mix_mle_bad = fit_mle(mix_bad_guess, X, maxiter = 2000, robust = true)
+        @test true
+    end
+    begin
+        #! no error thrown, however the SEM keeps one mixture component at zero probability (unchanged at every iteration)
+        mix_mle_S = fit_mle(mix_bad_guess, X, method = StochasticEM(), maxiter = 2000)
+        @test true
+    end
 end
\ No newline at end of file
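
Note: the last begin/end block only checks that the SEM call does not throw. If one wanted to pin down the dropout behavior described in its comment, a hypothetical tightening (not in the patch) could read the fitted weights with Distributions' probs accessor:

    α̂ = probs(mix_mle_S)   # mixture weights of the SEM fit
    @test any(iszero, α̂)   # the dropped-out component keeps exactly zero probability
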