Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
dmetivie committed Nov 29, 2023
2 parents e4b6a38 + 99f7b24 commit 4f86333
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 21 deletions.
11 changes: 7 additions & 4 deletions src/classic_em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ The EM algorithm was introduced by A. P. Dempster, N. M. Laird and D. B. Rubin i
struct ClassicEM <: AbstractEM end

"""
fit_mle!(α::AbstractVector, dists::AbstractVector{F} where {F<:Distribution}, y::AbstractVecOrMat, method::ClassicEM; display=:none, maxiter=1000, atol=1e-3, robust=false)
fit_mle!(α::AbstractVector, dists::AbstractVector{F} where {F<:Distribution}, y::AbstractVecOrMat, method::ClassicEM; display=:none, maxiter=1000, atol=1e-3, rtol=nothing, robust=false)
Use the EM algorithm to update the Distribution `dists` and weights `α` composing a mixture distribution.
- `robust = true` will prevent the (log)likelihood to overflow to `-∞` or `∞`.
- `atol` criteria determining the convergence of the algorithm. If the Loglikelihood difference between two iteration `i` and `i+1` is smaller than `atol` i.e. `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<atol`, the algorithm stops.
- `atol` criteria determining the convergence of the algorithm. If the Loglikelihood difference between two iteration `i` and `i+1` is smaller than `atol` i.e. `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<atol`, the algorithm stops.
- `rtol` relative tolerance for convergence, `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<rtol*(|ℓ⁽ⁱ⁺¹⁾| + |ℓ⁽ⁱ⁾|)/2` (does not check if `rtol` is `nothing`)
- `display` value can be `:none`, `:iter`, `:final` to display Loglikelihood evolution at each iterations `:iter` or just the final one `:final`
"""
function fit_mle!(
Expand All @@ -19,6 +20,7 @@ function fit_mle!(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
)

Expand Down Expand Up @@ -58,7 +60,7 @@ function fit_mle!(
push!(history["logtots"], logtotp)
history["iterations"] += 1

if abs(logtotp - logtot) < atol
if abs(logtotp - logtot) < atol || (rtol !== nothing && abs(logtotp - logtot) < rtol * (abs(logtot) + abs(logtotp)) / 2)
(display in [:iter, :final]) &&
println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
history["converged"] = true
Expand Down Expand Up @@ -88,6 +90,7 @@ function fit_mle!(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
)

Expand Down Expand Up @@ -127,7 +130,7 @@ function fit_mle!(
push!(history["logtots"], logtotp)
history["iterations"] += 1

if abs(logtotp - logtot) < atol
if abs(logtotp - logtot) < atol || (rtol !== nothing && abs(logtotp - logtot) < rtol * (abs(logtot) + abs(logtotp)) / 2)
(display in [:iter, :final]) &&
println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
history["converged"] = true
Expand Down
16 changes: 11 additions & 5 deletions src/fit_em.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""
fit_mle(mix::MixtureModel, y::AbstractVecOrMat, weights...; method = ClassicEM(), display=:none, maxiter=1000, atol=1e-3, robust=false, infos=false)
fit_mle(mix::MixtureModel, y::AbstractVecOrMat, weights...; method = ClassicEM(), display=:none, maxiter=1000, atol=1e-3, rtol=nothing, robust=false, infos=false)
Use an Expectation Maximization (EM) algorithm to maximize the Loglikelihood (fit) of the mixture with an i.i.d sample `y`.
The `mix` input is a mixture that is used to initialize the EM algorithm.
- `weights` when provided, it will compute a weighted version of the EM. (Useful for fitting mixture of mixtures)
- `method` determines the algorithm used.
- `infos = true` returns a `Dict` with information on the algorithm (converged, iteration number, loglikelihood).
- `robust = true` will prevent the (log)likelihood to overflow to `-∞` or `∞`.
- `atol` criteria determining the convergence of the algorithm. If the Loglikelihood difference between two iteration `i` and `i+1` is smaller than `atol` i.e. `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<atol`, the algorithm stops.
- `atol` criteria determining the convergence of the algorithm. If the Loglikelihood difference between two iteration `i` and `i+1` is smaller than `atol` i.e. `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<atol`, the algorithm stops.
- `rtol` relative tolerance for convergence, `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<rtol*(|ℓ⁽ⁱ⁺¹⁾| + |ℓ⁽ⁱ⁾|)/2` (does not check if `rtol` is `nothing`)
- `display` value can be `:none`, `:iter`, `:final` to display Loglikelihood evolution at each iterations `:iter` or just the final one `:final`
"""
function fit_mle(
Expand All @@ -17,6 +18,7 @@ function fit_mle(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
infos = false,
)
Expand All @@ -35,6 +37,7 @@ function fit_mle(
display = display,
maxiter = maxiter,
atol = atol,
rtol = rtol,
robust = robust,
)
else
Expand All @@ -47,6 +50,7 @@ function fit_mle(
display = display,
maxiter = maxiter,
atol = atol,
rtol = rtol,
robust = robust,
)
end
Expand All @@ -55,10 +59,10 @@ function fit_mle(
end

"""
fit_mle(mix::AbstractArray{<:MixtureModel}, y::AbstractVecOrMat, weights...; method = ClassicEM(), display=:none, maxiter=1000, atol=1e-3, robust=false, infos=false)
fit_mle(mix::AbstractArray{<:MixtureModel}, y::AbstractVecOrMat, weights...; method = ClassicEM(), display=:none, maxiter=1000, atol=1e-3, rtol=nothing, robust=false, infos=false)
Do the same as `fit_mle` for each (initial) mixtures in the mix array. Then it selects the one with the largest loglikelihood.
Warning: It uses try and catch to avoid errors messages in case EM converges toward a singular solution (probably using robust should be enough in most case to avoid errors).
Warning: It uses try and catch to avoid error messages in case EM converges toward a singular solution (probably using robust should be enough in most cases to avoid errors).
"""
function fit_mle(
mix::AbstractArray{<:MixtureModel},
Expand All @@ -68,6 +72,7 @@ function fit_mle(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
infos = false,
)
Expand All @@ -93,6 +98,7 @@ function fit_mle(
display = display,
maxiter = maxiter,
atol = atol,
rtol = rtol,
robust = robust,
infos = true,
)
Expand Down Expand Up @@ -148,7 +154,7 @@ function E_step!(
γ[:, :] = exp.(LL .- c)
end

# Utilities
# Utilities

size_sample(y::AbstractMatrix) = size(y, 2)
size_sample(y::AbstractVector) = length(y)
Expand Down
25 changes: 15 additions & 10 deletions src/stochastic_em.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Base.@kwdef struct StochasticEM<:AbstractEM
Base.@kwdef struct StochasticEM<:AbstractEM
rng::AbstractRNG = Random.GLOBAL_RNG
end
The Stochastic EM algorithm was introduced by G. Celeux and J. Diebolt in 1985 in [*The SEM Algorithm: A probabilistic teacher algorithm derived from the EM algorithm for the mixture problem*](https://cir.nii.ac.jp/crid/1574231874553755008).
Expand All @@ -16,7 +16,8 @@ end
fit_mle!(α::AbstractVector, dists::AbstractVector{F} where {F<:Distribution}, y::AbstractVecOrMat, method::StochasticEM; display=:none, maxiter=1000, atol=1e-3, rtol=nothing, robust=false)
Use the stochastic EM algorithm to update the Distribution `dists` and weights `α` composing a mixture distribution.
- `robust = true` will prevent the (log)likelihood to overflow to `-∞` or `∞`.
- `atol` criteria determining the convergence of the algorithm. If the Loglikelihood difference between two iteration `i` and `i+1` is smaller than `atol` i.e. `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<atol`, the algorithm stops.
- `atol` criteria determining the convergence of the algorithm. If the Loglikelihood difference between two iteration `i` and `i+1` is smaller than `atol` i.e. `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<atol`, the algorithm stops.
- `rtol` relative tolerance for convergence, `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾|<rtol*(|ℓ⁽ⁱ⁺¹⁾| + |ℓ⁽ⁱ⁾|)/2` (does not check if `rtol` is `nothing`)
- `display` value can be `:none`, `:iter`, `:final` to display Loglikelihood evolution at each iterations `:iter` or just the final one `:final`
"""
function fit_mle!(
Expand All @@ -27,6 +28,7 @@ function fit_mle!(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
)

Expand All @@ -53,7 +55,7 @@ function fit_mle!(
# S-step
ẑ[:] .= [rand(method.rng, Categorical(ℙ...)) for ℙ in eachrow(γ)]
cat = [findall(ẑ .== k) for k = 1:K]

# M-step
# using ẑ, maximize (update) the parameters
α[:] = length.(cat)/N
Expand All @@ -70,7 +72,7 @@ function fit_mle!(
push!(history["logtots"], logtotp)
history["iterations"] += 1

if abs(logtotp - logtot) < atol
if abs(logtotp - logtot) < atol || (rtol !== nothing && abs(logtotp - logtot) < rtol * (abs(logtot) + abs(logtotp)) / 2)
(display in [:iter, :final]) &&
println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
history["converged"] = true
Expand Down Expand Up @@ -100,6 +102,7 @@ function fit_mle!(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
)

Expand All @@ -126,7 +129,7 @@ function fit_mle!(
# S-step
ẑ = [rand(method.rng, Categorical(ℙ...)) for ℙ in eachrow(γ)]
cat = [findall(ẑ .== k) for k = 1:K]

# M-step
# using ẑ, maximize (update) the parameters
α[:] = length.(cat)/N
Expand All @@ -143,7 +146,7 @@ function fit_mle!(
push!(history["logtots"], logtotp)
history["iterations"] += 1

if abs(logtotp - logtot) < atol
if abs(logtotp - logtot) < atol || (rtol !== nothing && abs(logtotp - logtot) < rtol * (abs(logtot) + abs(logtotp)) / 2)
(display in [:iter, :final]) &&
println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
history["converged"] = true
Expand Down Expand Up @@ -174,6 +177,7 @@ function fit_mle!(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
)

Expand All @@ -200,7 +204,7 @@ function fit_mle!(
# S-step
ẑ = [rand(method.rng, Categorical(ℙ...)) for ℙ in eachrow(γ)]
cat = [findall(ẑ .== k) for k = 1:K]

# M-step
# using ẑ, maximize (update) the parameters
α[:] = [length(cat[k])*sum(w[cat[k]]) for k in 1:K]/sum(w)
Expand All @@ -217,7 +221,7 @@ function fit_mle!(
push!(history["logtots"], logtotp)
history["iterations"] += 1

if abs(logtotp - logtot) < atol
if abs(logtotp - logtot) < atol || (rtol !== nothing && abs(logtotp - logtot) < rtol * (abs(logtot) + abs(logtotp)) / 2)
(display in [:iter, :final]) &&
println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
history["converged"] = true
Expand Down Expand Up @@ -248,6 +252,7 @@ function fit_mle!(
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
)

Expand All @@ -274,7 +279,7 @@ function fit_mle!(
# S-step
ẑ = [rand(method.rng, Categorical(ℙ...)) for ℙ in eachrow(γ)]
cat = [findall(ẑ .== k) for k = 1:K]

# M-step
# using ẑ, maximize (update) the parameters
α[:] = [sum(w[cat[k]]) for k in 1:K]/sum(w)
Expand All @@ -291,7 +296,7 @@ function fit_mle!(
push!(history["logtots"], logtotp)
history["iterations"] += 1

if abs(logtotp - logtot) < atol
if abs(logtotp - logtot) < atol || (rtol !== nothing && abs(logtotp - logtot) < rtol * (abs(logtot) + abs(logtotp)) / 2)
(display in [:iter, :final]) &&
println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
history["converged"] = true
Expand Down
30 changes: 28 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ using Random
@test isapprox(θ₁, p[1]...; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)

# Test rtol
mix_mle2 =
fit_mle(mix_guess, y; display = :none, rtol = 1e-8, atol = 0, robust = false, infos = false)
p = params(mix_mle2)[1]
@test isapprox([β, 1 - β], probs(mix_mle2); rtol = rtol)
@test isapprox(θ₁, p[1]...; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)
end

@testset "Stochastic EM Univariate continuous Mixture Exponential + Laplace" begin
Expand Down Expand Up @@ -53,6 +62,23 @@ end
@test isapprox(μ, p[1][1]; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)

mix_mle2 = fit_mle(
mix_guess,
y;
display = :none,
atol = 0,
rtol = 1e-6,
robust = false,
infos = false,
method = StochasticEM(),
)
p = params(mix_mle2)[1]
@test isapprox([β, 1 - β], probs(mix_mle2); rtol = rtol)
@test isapprox(θ₁, p[1][2]; rtol = rtol)
@test isapprox(μ, p[1][1]; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)
end

@testset "Multivariate Gaussian Mixture" begin
Expand Down Expand Up @@ -143,7 +169,7 @@ end
α = 1 / 2
β = 0.3

rtol = 5e-2 #
rtol = 5e-2 #
d1 = MixtureModel([Normal(θ₁, σ₁), Normal(θ₂, σ₂)], [α, 1 - α])
d2 = Normal(θ₀, σ₀)
mix_true = MixtureModel([d1, d2], [β, 1 - β])
Expand Down Expand Up @@ -186,7 +212,7 @@ end
α = 1 / 2
β = 0.5

rtol = 5e-2 #
rtol = 5e-2 #
d1 = MixtureModel([Normal(θ₁, σ₁), Laplace(θ₂, σ₂)], [α, 1 - α])
d2 = Normal(θ₀, σ₀)
mix_true = MixtureModel([d1, d2], [β, 1 - β])
Expand Down

0 comments on commit 4f86333

Please sign in to comment.