From 1a4b4e5a75d295742097e52e9bdd8ed0b7e5fd65 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 4 Feb 2021 10:42:56 +0100 Subject: [PATCH] Separate Gibbs handling (#1500) * Separate Gibbs handling * Bump version * Remove comment * Fix Gibbs step * Update comment * Bump version * Recompute `vi.logp` in Gibbs sampler * Define `gibbs_rerun` * Bump version * Simplify code * Update src/inference/gibbs.jl * Revert `gid` related changes --- Project.toml | 2 +- src/contrib/inference/dynamichmc.jl | 39 +++++------- src/inference/AdvancedSMC.jl | 8 +-- src/inference/Inference.jl | 14 ++--- src/inference/ess.jl | 7 --- src/inference/gibbs.jl | 98 +++++++++++++++++++++++++---- src/inference/gibbs_conditional.jl | 9 +-- src/inference/hmc.jl | 62 +++--------------- src/inference/mh.jl | 13 +--- test/inference/gibbs_conditional.jl | 4 +- 10 files changed, 128 insertions(+), 128 deletions(-) diff --git a/Project.toml b/Project.toml index a4c74f84b..1e2ca435e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.15.8" +version = "0.15.9" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 6f137512d..9a1cadd3a 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -56,8 +56,17 @@ struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S} stepsize::S end -function gibbs_update_state(state::DynamicNUTSState, varinfo::AbstractVarInfo) - return DynamicNUTSState(varinfo, state.cache, state.metric, state.stepsize) +# Implement interface of `Gibbs` sampler +function gibbs_state( + model::Model, + spl::Sampler{<:DynamicNUTS}, + state::DynamicNUTSState, + varinfo::AbstractVarInfo, +) + # Update the previous evaluation. + ℓ = DynamicHMCLogDensity(model, spl, varinfo) + Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl]) + return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize) end DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform() @@ -90,11 +99,6 @@ function DynamicPPL.initialstep( vi[spl] = Q.q DynamicPPL.setlogp!(vi, Q.ℓq) - # If a Gibbs component, transform the values back to the constrained space. - if spl.selector.tag !== :default - DynamicPPL.invlink!(vi, spl) - end - # Create first sample and state. sample = Transition(vi) state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ) @@ -119,28 +123,15 @@ function AbstractMCMC.step( ℓ, state.stepsize, ) - Q = if spl.selector.tag !== :default - # When a Gibbs component, transform values to the unconstrained space - # and update the previous evaluation. - DynamicPPL.link!(vi, spl) - DynamicHMC.evaluate_ℓ(ℓ, vi[spl]) - else - state.cache - end - newQ, _ = DynamicHMC.mcmc_next_step(steps, Q) + Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache) # Update the variables. - vi[spl] = newQ.q - DynamicPPL.setlogp!(vi, newQ.ℓq) - - # If a Gibbs component, transform the values back to the constrained space. - if spl.selector.tag !== :default - DynamicPPL.invlink!(vi, spl) - end + vi[spl] = Q.q + DynamicPPL.setlogp!(vi, Q.ℓq) # Create next sample and state. sample = Transition(vi) - newstate = DynamicNUTSState(vi, newQ, state.metric, state.stepsize) + newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize) return sample, newstate end diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index db920357c..66cf9c250 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -177,8 +177,6 @@ struct PG{space,R} <: ParticleInference resampler::R end -isgibbscomponent(::PG) = true - """ PG(n, space...) PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold(), space = ()]) @@ -329,10 +327,10 @@ function DynamicPPL.assume( unset_flag!(vi, vn, "del") r = rand(rng, dist) vi[vn] = vectorize(dist, r) - setgid!(vi, spl.selector, vn) + DynamicPPL.setgid!(vi, spl.selector, vn) setorder!(vi, vn, get_num_produce(vi)) else - updategid!(vi, vn, spl) + DynamicPPL.updategid!(vi, vn, spl) r = vi[vn] end else # vn belongs to other sampler <=> conditionning on vn @@ -340,7 +338,7 @@ function DynamicPPL.assume( r = vi[vn] else r = rand(rng, dist) - push!(vi, vn, r, dist, Selector(:invalid)) + push!(vi, vn, r, dist, DynamicPPL.Selector(:invalid)) end lp = logpdf_with_trans(dist, r, istrans(vi, vn)) acclogp!(vi, lp) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index baee2ce20..2e2db8ed7 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -2,12 +2,12 @@ module Inference using ..Core using ..Utilities -using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, +using DynamicPPL: Metadata, VarInfo, TypedVarInfo, islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, - settrans!, _getvns, getdist, CACHERESET, + settrans!, _getvns, getdist, Model, Sampler, SampleFromPrior, SampleFromUniform, - Selector, DefaultContext, PriorContext, - LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist, + DefaultContext, PriorContext, + LikelihoodContext, set_flag!, unset_flag!, getspace, inspace using Distributions, Libtask, Bijectors using DistributionsAD: VectorOfMultivariate @@ -25,8 +25,6 @@ import AdvancedMH; const AMH = AdvancedMH import AdvancedPS import BangBang import ..Core: getchunksize, getADbackend -import DynamicPPL: get_matching_type, - VarName, _getranges, _getindex, getval, _getvns import EllipticalSliceSampling import Random import MCMCChains @@ -414,8 +412,8 @@ include("hmc.jl") include("mh.jl") include("is.jl") include("AdvancedSMC.jl") -include("gibbs.jl") include("gibbs_conditional.jl") +include("gibbs.jl") include("../contrib/inference/sghmc.jl") include("emcee.jl") @@ -430,7 +428,7 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) @eval DynamicPPL.getspace(::$alg{<:Any, space}) where {space} = space end -function get_matching_type( +function DynamicPPL.get_matching_type( spl::Sampler{<:Union{PG, SMC}}, vi, ::Type{TV}, diff --git a/src/inference/ess.jl b/src/inference/ess.jl index d3c7021cf..ebcfc4a17 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -25,8 +25,6 @@ struct ESS{space} <: InferenceAlgorithm end ESS() = ESS{()}() ESS(space::Symbol) = ESS{(space,)}() -isgibbscomponent(::ESS) = true - # always accept in the first step function DynamicPPL.initialstep( rng::AbstractRNG, @@ -58,11 +56,6 @@ function AbstractMCMC.step( # obtain previous sample f = vi[spl] - # recompute log-likelihood in logp - if spl.selector.tag !== :default - model(rng, vi, spl) - end - # define previous sampler state # (do not use cache to avoid in-place sampling from prior) oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing) diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index e5a6df443..6f1507805 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -10,6 +10,11 @@ Determine whether algorithm `alg` is allowed as a Gibbs component. """ isgibbscomponent(alg) = false +isgibbscomponent(::ESS) = true +isgibbscomponent(::GibbsConditional) = true +isgibbscomponent(::Hamiltonian) = true +isgibbscomponent(::MH) = true +isgibbscomponent(::PG) = true """ Gibbs(algs...) @@ -83,11 +88,72 @@ metadata(t::GibbsTransition) = (lp = t.lp,) DynamicPPL.getlogp(t::GibbsTransition) = t.lp # extract varinfo object from state -getvarinfo(state) = state.vi -getvarinfo(state::AbstractVarInfo) = state +""" + gibbs_varinfo(model, sampler, state) + +Return the variables corresponding to the current `state` of the Gibbs component `sampler`. +""" +gibbs_varinfo(model, sampler, state) = varinfo(state) +varinfo(state) = state.vi +varinfo(state::AbstractVarInfo) = state + +""" + gibbs_state(model, sampler, state, varinfo) + +Return an updated state, taking into account the variables sampled by other Gibbs components. + +# Arguments +- `model`: model targeted by the Gibbs sampler. +- `sampler`: the sampler for this Gibbs component. +- `state`: the state of `sampler` computed in the previous iteration. +- `varinfo`: the variables, including the ones sampled by other Gibbs components. +""" +gibbs_state(model, sampler, state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo + +# Update state in Gibbs sampling +function gibbs_state( + model::Model, + spl::Sampler{<:Hamiltonian}, + state::HMCState, + varinfo::AbstractVarInfo, +) + # Update hamiltonian + θ_old = varinfo[spl] + hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_old)) -# update state with new varinfo object -gibbs_update_state(state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo + # TODO: Avoid mutation + resize!(state.z.θ, length(θ_old)) + state.z.θ .= θ_old + z = state.z + + return HMCState(varinfo, state.i, state.traj, hamiltonian, z, state.adaptor) +end + +""" + gibbs_rerun(prev_alg, alg) + +Check if the model should be rerun to recompute the log density before sampling with the +Gibbs component `alg` and after sampling from Gibbs component `prev_alg`. + +By default, the function returns `true`. +""" +gibbs_rerun(prev_alg, alg) = true + +# `vi.logp` already contains the log joint probability if the previous sampler +# used a `GibbsConditional` or one of the standard `Hamiltonian` algorithms +gibbs_rerun(::GibbsConditional, ::MH) = false +gibbs_rerun(::Hamiltonian, ::MH) = false + +# `vi.logp` already contains the log joint probability if the previous sampler +# used a `GibbsConditional` or a `MH` algorithm +gibbs_rerun(::MH, ::Hamiltonian) = false +gibbs_rerun(::GibbsConditional, ::Hamiltonian) = false + +# do not have to recompute `vi.logp` since it is not used in `step` +gibbs_rerun(prev_alg, ::GibbsConditional) = false + +# Do not recompute `vi.logp` since it is reset anyway in `step` +gibbs_rerun(prev_alg, ::PG) = false # Initialize the Gibbs sampler. function DynamicPPL.initialstep( @@ -107,9 +173,8 @@ function DynamicPPL.initialstep( else prev_alg = algs[i-1] end - rerun = !isa(alg, MH) || prev_alg isa PG || prev_alg isa ESS || - prev_alg isa GibbsConditional - selector = Selector(Symbol(typeof(alg)), rerun) + rerun = gibbs_rerun(prev_alg, alg) + selector = DynamicPPL.Selector(Symbol(typeof(alg)), rerun) Sampler(alg, model, selector) end @@ -130,10 +195,16 @@ function DynamicPPL.initialstep( # Compute initial states of the local samplers. states = map(samplers) do local_spl + # Recompute `vi.logp` if needed. + if local_spl.selector.rerun + model(rng, vi, local_spl) + end + + # Compute initial state. _, state = DynamicPPL.initialstep(rng, model, local_spl, vi; kwargs...) - # update VarInfo object - vi = getvarinfo(state) + # Update `VarInfo` object. + vi = gibbs_varinfo(model, local_spl, state) return state end @@ -157,14 +228,19 @@ function AbstractMCMC.step( vi = state.vi samplers = state.samplers states = map(samplers, state.states) do _sampler, _state + # Recompute `vi.logp` if needed. + if _sampler.selector.rerun + model(rng, vi, _sampler) + end + # Update state of current sampler with updated `VarInfo` object. - current_state = gibbs_update_state(_state, vi) + current_state = gibbs_state(model, _sampler, _state, vi) # Step through the local sampler. _, newstate = AbstractMCMC.step(rng, model, _sampler, current_state; kwargs...) # Update `VarInfo` object. - vi = getvarinfo(newstate) + vi = gibbs_varinfo(model, _sampler, newstate) return newstate end diff --git a/src/inference/gibbs_conditional.jl b/src/inference/gibbs_conditional.jl index 3b7758b1d..f923d5f86 100644 --- a/src/inference/gibbs_conditional.jl +++ b/src/inference/gibbs_conditional.jl @@ -59,8 +59,6 @@ end DynamicPPL.getspace(::GibbsConditional{S}) where {S} = (S,) -isgibbscomponent(::GibbsConditional) = true - function DynamicPPL.initialstep( rng::AbstractRNG, model::Model, @@ -68,7 +66,7 @@ function DynamicPPL.initialstep( vi::AbstractVarInfo; kwargs... ) - return AbstractMCMC.step(rng, model, spl, vi; kwargs...) + return nothing, vi end function AbstractMCMC.step( @@ -78,14 +76,11 @@ function AbstractMCMC.step( vi::AbstractVarInfo; kwargs... ) - if spl.selector.rerun # Recompute joint in logp - model(rng, vi) - end - condvals = conditioned(tonamedtuple(vi)) conddist = spl.alg.conditional(condvals) updated = rand(rng, conddist) vi[spl] = [updated;] # setindex allows only vectors in this case... + model(rng, vi, SampleFromPrior()) # update log joint probability return nothing, vi end diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 3e34c4087..8f76aab24 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -17,11 +17,6 @@ struct HMCState{ adaptor::TAdapt end -# TODO: Include recompute Hamiltonian here? -function gibbs_update_state(state::HMCState, varinfo::AbstractVarInfo) - return HMCState(varinfo, state.i, state.traj, state.hamiltonian, state.z, state.adaptor) -end - ########################## # Hamiltonian Transition # ########################## @@ -81,8 +76,6 @@ struct HMC{AD, space, metricT <: AHMC.AbstractMetric} <: StaticHamiltonian{AD} n_leapfrog::Int # leapfrog step number end -isgibbscomponent(::Hamiltonian) = true - HMC(args...; kwargs...) = HMC{ADBackend()}(args...; kwargs...) function HMC{AD}(ϵ::Float64, n_leapfrog::Int, ::Type{metricT}, space::Tuple) where {AD, metricT <: AHMC.AbstractMetric} return HMC{AD, space, metricT}(ϵ, n_leapfrog) @@ -224,11 +217,6 @@ function DynamicPPL.initialstep( transition = HMCTransition(vi, t) state = HMCState(vi, 1, traj, hamiltonian, t.z, adaptor) - # If a Gibbs component, transform the values back to the constrained space. - if spl.selector.tag !== :default - invlink!(vi, spl) - end - return transition, state end @@ -241,40 +229,15 @@ function AbstractMCMC.step( kwargs... ) # Get step size - ϵ = getstepsize(spl, state) - @debug "current ϵ" ϵ - - # Get VarInfo object - vi = state.vi - i = state.i + 1 - - # When a Gibbs component, transform values to the unconstrained space. - if spl.selector.tag !== :default - link!(vi, spl) - model(rng, vi, spl) - end - - # Get position and log density before transition - θ_old = vi[spl] - log_density_old = getlogp(vi) - hamiltonian = if spl.selector.tag === :default - state.hamiltonian - else - get_hamiltonian(model, spl, vi, state, length(θ_old)) - end - - z = if spl.selector.tag === :default - state.z - else - resize!(state.z.θ, length(θ_old)) - state.z.θ .= θ_old - state.z - end + @debug "current ϵ" getstepsize(spl, state) # Compute transition. + hamiltonian = state.hamiltonian + z = state.z t = AHMC.step(rng, hamiltonian, state.traj, z) # Adaptation + i = state.i + 1 if spl.alg isa AdaptiveHamiltonian hamiltonian, traj, _ = AHMC.adapt!(hamiltonian, state.traj, state.adaptor, @@ -283,24 +246,17 @@ function AbstractMCMC.step( traj = state.traj end - # Update `vi` based on acceptance + # Update variables + vi = state.vi if t.stat.is_accept vi[spl] = t.z.θ setlogp!(vi, t.stat.log_density) - else - vi[spl] = θ_old - setlogp!(vi, log_density_old) end # Compute next transition and state. transition = HMCTransition(vi, t) newstate = HMCState(vi, i, traj, hamiltonian, t.z, state.adaptor) - # If a Gibbs component, transform the values back to the constrained space. - if spl.selector.tag !== :default - invlink!(vi, spl) - end - return transition, newstate end @@ -517,7 +473,7 @@ function DynamicPPL.assume( vn::VarName, vi, ) - updategid!(vi, vn, spl) + DynamicPPL.updategid!(vi, vn, spl) r = vi[vn] # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn))) # r @@ -533,7 +489,7 @@ function DynamicPPL.dot_assume( vi, ) @assert length(dist) == size(var, 1) - updategid!.(Ref(vi), vns, Ref(spl)) + DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) r = vi[vns] var .= r return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) @@ -546,7 +502,7 @@ function DynamicPPL.dot_assume( var::AbstractArray, vi, ) - updategid!.(Ref(vi), vns, Ref(spl)) + DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) r = reshape(vi[vec(vns)], size(var)) var .= r return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 00e135bda..4b9e74930 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -186,8 +186,6 @@ function MH(space...) return MH{tuple(syms...), typeof(proposals)}(proposals) end -isgibbscomponent(::MH) = true - ##################### # Utility functions # ##################### @@ -413,11 +411,6 @@ function AbstractMCMC.step( vi::AbstractVarInfo; kwargs... ) - # Recompute joint - if spl.selector.rerun - model(rng, vi) - end - # Cases: # 1. A covariance proposal matrix # 2. A bunch of NamedTuples that specify the proposal space @@ -436,7 +429,7 @@ function DynamicPPL.assume( vn::VarName, vi, ) - updategid!(vi, vn, spl) + DynamicPPL.updategid!(vi, vn, spl) r = vi[vn] return r, logpdf_with_trans(dist, r, istrans(vi, vn)) end @@ -452,7 +445,7 @@ function DynamicPPL.dot_assume( @assert dim(dist) == size(var, 1) getvn = i -> VarName(vn, vn.indexing * "[:,$i]") vns = getvn.(1:size(var, 2)) - updategid!.(Ref(vi), vns, Ref(spl)) + DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) r = vi[vns] var .= r return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) @@ -467,7 +460,7 @@ function DynamicPPL.dot_assume( ) getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]") vns = getvn.(CartesianIndices(var)) - updategid!.(Ref(vi), vns, Ref(spl)) + DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl)) r = reshape(vi[vec(vns)], size(var)) var .= r return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) diff --git a/test/inference/gibbs_conditional.jl b/test/inference/gibbs_conditional.jl index cd5172a70..7a92637fe 100644 --- a/test/inference/gibbs_conditional.jl +++ b/test/inference/gibbs_conditional.jl @@ -51,7 +51,7 @@ include(dir*"/test/test_utils/AllUtils.jl") chain = sample(gdemo_default, sampler1, 10_000) cond_m_mean = mean(cond_m((s = s_posterior_mean,))) check_numerical(chain, [:m, :s], [cond_m_mean, s_posterior_mean]) - @test all(==(s_posterior_mean), chain[:s]) + @test all(==(s_posterior_mean), chain[:s][2:end]) m_posterior_mean = 7/6 sampler2 = Gibbs( @@ -61,7 +61,7 @@ include(dir*"/test/test_utils/AllUtils.jl") chain = sample(gdemo_default, sampler2, 10_000) cond_s_mean = mean(cond_s((m = m_posterior_mean,))) check_numerical(chain, [:m, :s], [m_posterior_mean, cond_s_mean]) - @test all(==(m_posterior_mean), chain[:m]) + @test all(==(m_posterior_mean), chain[:m][2:end]) # and one for both using the conditional sampler3 = Gibbs(GibbsConditional(:m, cond_m), GibbsConditional(:s, cond_s))