Merge remote-tracking branch 'origin/master' into prior_optctxt
mhauru committed Jun 5, 2024
2 parents d39fe6f + ebc36af commit c4f62ad
Showing 9 changed files with 1,054 additions and 567 deletions.
4 changes: 4 additions & 0 deletions Project.toml
@@ -25,6 +25,8 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -68,6 +70,8 @@ LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "5, 6"
NamedArrays = "0.9, 0.10"
Optimization = "3"
OptimizationOptimJL = "0.1, 0.2, 0.3"
OrderedCollections = "1"
Optim = "1"
Reexport = "0.2, 1"
193 changes: 67 additions & 126 deletions ext/TuringOptimExt.jl
@@ -2,105 +2,22 @@ module TuringOptimExt

if isdefined(Base, :get_extension)
import Turing
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import Turing:
DynamicPPL,
NamedArrays,
Accessors,
Optimisation
import Optim
else
import ..Turing
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import ..Turing:
DynamicPPL,
NamedArrays,
Accessors,
Optimisation
import ..Optim
end

"""
ModeResult{
V<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
M<:Turing.OptimLogDensity
}
A wrapper struct to store various results from a MAP or MLE estimation.
"""
struct ModeResult{
V<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
M<:Turing.OptimLogDensity
} <: StatsBase.StatisticalModel
"A vector with the resulting point estimates."
values::V
"The stored Optim.jl results."
optim_result::O
"The final log likelihood or log joint, depending on whether `MAP` or `MLE` was run."
lp::Float64
"The evaluation function used to calculate the output."
f::M
end
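For orientation, the sketch below shows how the fields of a `ModeResult` can be inspected once a fit has been run. The `gdemo` model and its data are hypothetical, used only to produce a result, and the default LBFGS optimizer is assumed.

```julia
using Turing, Optim

# Hypothetical toy model, used only to produce a ModeResult to inspect.
@model function gdemo(x)
    σ ~ truncated(Normal(0, 2); lower=0)
    μ ~ Normal(0, 1)
    for i in eachindex(x)
        x[i] ~ Normal(μ, σ)
    end
end

model = gdemo([1.5, 2.0, 2.1])
result = optimize(model, MLE())

result.values        # NamedArray of point estimates, indexable by parameter name
result.lp            # maximised log likelihood (log joint for a MAP fit)
result.optim_result  # the raw Optim.MultivariateOptimizationResults
```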
#############################
# Various StatsBase methods #
#############################



function Base.show(io::IO, ::MIME"text/plain", m::ModeResult)
print(io, "ModeResult with maximized lp of ")
Printf.@printf(io, "%.2f", m.lp)
println(io)
show(io, m.values)
end

function Base.show(io::IO, m::ModeResult)
show(io, m.values.array)
end

function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
# Get columns for coeftable.
terms = string.(StatsBase.coefnames(m))
estimates = m.values.array[:, 1]
stderrors = StatsBase.stderror(m)
zscore = estimates ./ stderrors
p = map(z -> StatsAPI.pvalue(Distributions.Normal(), z; tail=:both), zscore)

# Confidence interval (CI)
q = Statistics.quantile(Distributions.Normal(), (1 + level) / 2)
ci_low = estimates .- q .* stderrors
ci_high = estimates .+ q .* stderrors

level_ = 100*level
level_percentage = isinteger(level_) ? Int(level_) : level_

StatsBase.CoefTable(
[estimates, stderrors, zscore, p, ci_low, ci_high],
["Coef.", "Std. Error", "z", "Pr(>|z|)", "Lower $(level_percentage)%", "Upper $(level_percentage)%"],
terms)
end
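Continuing the hypothetical `result` from the sketch above, the coefficient table can be requested at a non-default confidence level; the 0.90 value is arbitrary.

```julia
# Estimates, standard errors, z-scores, p-values, and 90% confidence bounds.
StatsBase.coeftable(result; level=0.90)
```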

function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...)
# Calculate Hessian and information matrix.

# Convert the values to their unconstrained states to make sure the
# Hessian is computed with respect to the untransformed parameters.
linked = DynamicPPL.istrans(m.f.varinfo)
if linked
m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
end

# Calculate the Hessian, which is the information matrix because the negative of the log likelihood was optimized
varnames = StatsBase.coefnames(m)
info = hessian_function(m.f, m.values.array[:, 1])

# Link it back if we invlinked it.
if linked
m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
end

return NamedArrays.NamedArray(info, (varnames, varnames))
end
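The variance-covariance matrix and standard errors used by `coeftable` derive from this information matrix. A brief sketch of that relationship, again using the hypothetical `result`:

```julia
info = StatsBase.informationmatrix(result)  # named Hessian of the negative log density
Σ    = StatsBase.vcov(result)               # inverse of the information matrix
se   = StatsBase.stderror(result)           # square roots of the diagonal of Σ
```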

StatsBase.coef(m::ModeResult) = m.values
StatsBase.coefnames(m::ModeResult) = names(m.values)[1]
StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
StatsBase.vcov(m::ModeResult) = inv(StatsBase.informationmatrix(m))
StatsBase.loglikelihood(m::ModeResult) = m.lp

####################
# Optim.jl methods #
####################
@@ -125,26 +42,41 @@ mle = optimize(model, MLE())
mle = optimize(model, MLE(), NelderMead())
```
"""
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, options::Optim.Options=Optim.Options(); kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Turing.OptimLogDensity(model, ctx)
function Optim.optimize(
model::DynamicPPL.Model, ::Optimisation.MLE, options::Optim.Options=Optim.Options();
kwargs...
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
optimizer = Optim.LBFGS()
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(
model::DynamicPPL.Model,
::Optimisation.MLE,
init_vals::AbstractArray,
options::Optim.Options=Optim.Options();
kwargs...
)
optimizer = Optim.LBFGS()
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Turing.OptimLogDensity(model, ctx)
function Optim.optimize(
model::DynamicPPL.Model,
::Optimisation.MLE,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
kwargs...
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(
model::DynamicPPL.Model,
::Turing.MLE,
::Optimisation.MLE,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
@@ -154,8 +86,8 @@ function Optim.optimize(
end

function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext())
return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
end
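The MLE methods above accept any combination of initial values, an Optim optimizer, and `Optim.Options`. A hedged sketch using the hypothetical `model` from earlier; the starting point and iteration budget are arbitrary:

```julia
# Default: LBFGS from a model-generated starting point.
mle = optimize(model, MLE())

# Explicit starting point, optimizer, and options.
mle = optimize(
    model,
    MLE(),
    [1.0, 0.0],                       # init_vals (σ, μ) in the model's own space
    Optim.NelderMead(),
    Optim.Options(iterations=2_000),
)
```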

"""
@@ -178,26 +110,41 @@ map_est = optimize(model, MAP())
map_est = optimize(model, MAP(), NelderMead())
```
"""
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, options::Optim.Options=Optim.Options(); kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
f = Turing.OptimLogDensity(model, ctx)
function Optim.optimize(
model::DynamicPPL.Model, ::Optimisation.MAP, options::Optim.Options=Optim.Options();
kwargs...
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(
model::DynamicPPL.Model,
::Optimisation.MAP,
init_vals::AbstractArray,
options::Optim.Options=Optim.Options();
kwargs...
)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
f = Turing.OptimLogDensity(model, ctx)
function Optim.optimize(
model::DynamicPPL.Model,
::Optimisation.MAP,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
kwargs...
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(
model::DynamicPPL.Model,
::Turing.MAP,
::Optimisation.MAP,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
@@ -207,8 +154,8 @@ function Optim.optimize(
end

function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext())
return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
end
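The MAP path mirrors the MLE path; the only difference is the evaluation context (`DefaultContext`, so the prior is included in the objective). A short sketch with the same hypothetical `model`:

```julia
# MAP with the default LBFGS optimizer.
map_est = optimize(model, MAP())

# MAP with an explicit optimizer and a tighter gradient tolerance.
map_est = optimize(model, MAP(), Optim.LBFGS(), Optim.Options(g_tol=1e-8))
```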

"""
@@ -218,7 +165,7 @@ Estimate a mode, i.e., compute an MLE or MAP estimate.
"""
function _optimize(
model::DynamicPPL.Model,
f::Turing.OptimLogDensity,
f::Optimisation.OptimLogDensity,
init_vals::AbstractArray=DynamicPPL.getparams(f),
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
options::Optim.Options=Optim.Options(),
@@ -236,25 +183,19 @@ function _optimize(

# Warn the user if the optimization did not converge.
if !Optim.converged(M)
@warn "Optimization did not converge! You may need to correct your model or adjust the Optim parameters."
@warn """
Optimization did not converge! You may need to correct your model or adjust the
Optim parameters.
"""
end

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
# Get the optimum in unconstrained space. `getparams` does the invlinking.
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
f = Accessors.@set f.varinfo = DynamicPPL.invlink(f.varinfo, model)
vals = DynamicPPL.getparams(f)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)

# Make one transition to get the parameter names.
vns_vals_iter = Turing.Inference.getparams(model, f.varinfo)
varnames = map(Symbol ∘ first, vns_vals_iter)
vals = map(last, vns_vals_iter)

# Store the parameters and their names in an array.
vmat = NamedArrays.NamedArray(vals, varnames)

return ModeResult(vmat, M, -M.minimum, f)
return Optimisation.ModeResult(vmat, M, -M.minimum, f)
end
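The convergence warning above is advisory only; callers can also check the stored Optim result themselves and retry with a larger budget. A sketch, continuing the hypothetical `map_est` from earlier:

```julia
if !Optim.converged(map_est.optim_result)
    # Retry with a larger iteration budget (the count is arbitrary).
    map_est = optimize(model, MAP(), Optim.Options(iterations=10_000))
end
```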

end # module
10 changes: 4 additions & 6 deletions src/Turing.jl
@@ -138,13 +138,11 @@ export @model, # modelling

ordered, # Exports from Bijectors

constrained_space, # optimisation interface
maximum_a_posteriori,
maximum_likelihood,
# The MAP and MLE exports are only needed for the Optim.jl interface.
MAP,
MLE,
get_parameter_bounds,
optim_objective,
optim_function,
optim_problem
MLE

if !isdefined(Base, :get_extension)
using Requires
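For downstream code, this export change means the removed helpers (`constrained_space`, `optim_objective`, `optim_function`, `optim_problem`, `get_parameter_bounds`) give way to the new top-level functions, while `MAP` and `MLE` remain only for the Optim.jl `optimize` interface. A hedged sketch of both calling styles, assuming a model as in the earlier examples:

```julia
using Turing

# New Optimization.jl-based interface (solver arguments are optional).
mle = maximum_likelihood(model)
map_est = maximum_a_posteriori(model)

# The Optim.jl interface still works once Optim is loaded.
using Optim
mle_optim = optimize(model, MLE(), NelderMead())
```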
