Skip to content

Commit

Permalink
Separate Gibbs handling (#1500)
Browse files Browse the repository at this point in the history
* Separate Gibbs handling

* Bump version

* Remove comment

* Fix Gibbs step

* Update comment

* Bump version

* Recompute `vi.logp` in Gibbs sampler

* Define `gibbs_rerun`

* Bump version

* Simplify code

* Update src/inference/gibbs.jl

* Revert `gid` related changes
  • Loading branch information
devmotion authored Feb 4, 2021
1 parent 8886d35 commit 1a4b4e5
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 128 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.15.8"
version = "0.15.9"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
39 changes: 15 additions & 24 deletions src/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,17 @@ struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S}
stepsize::S
end

function gibbs_update_state(state::DynamicNUTSState, varinfo::AbstractVarInfo)
return DynamicNUTSState(varinfo, state.cache, state.metric, state.stepsize)
# Implement interface of `Gibbs` sampler
function gibbs_state(
model::Model,
spl::Sampler{<:DynamicNUTS},
state::DynamicNUTSState,
varinfo::AbstractVarInfo,
)
# Update the previous evaluation.
= DynamicHMCLogDensity(model, spl, varinfo)
Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl])
return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize)
end

DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
Expand Down Expand Up @@ -90,11 +99,6 @@ function DynamicPPL.initialstep(
vi[spl] = Q.q
DynamicPPL.setlogp!(vi, Q.ℓq)

# If a Gibbs component, transform the values back to the constrained space.
if spl.selector.tag !== :default
DynamicPPL.invlink!(vi, spl)
end

# Create first sample and state.
sample = Transition(vi)
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)
Expand All @@ -119,28 +123,15 @@ function AbstractMCMC.step(
ℓ,
state.stepsize,
)
Q = if spl.selector.tag !== :default
# When a Gibbs component, transform values to the unconstrained space
# and update the previous evaluation.
DynamicPPL.link!(vi, spl)
DynamicHMC.evaluate_ℓ(ℓ, vi[spl])
else
state.cache
end
newQ, _ = DynamicHMC.mcmc_next_step(steps, Q)
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

# Update the variables.
vi[spl] = newQ.q
DynamicPPL.setlogp!(vi, newQ.ℓq)

# If a Gibbs component, transform the values back to the constrained space.
if spl.selector.tag !== :default
DynamicPPL.invlink!(vi, spl)
end
vi[spl] = Q.q
DynamicPPL.setlogp!(vi, Q.ℓq)

# Create next sample and state.
sample = Transition(vi)
newstate = DynamicNUTSState(vi, newQ, state.metric, state.stepsize)
newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize)

return sample, newstate
end
8 changes: 3 additions & 5 deletions src/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ struct PG{space,R} <: ParticleInference
resampler::R
end

isgibbscomponent(::PG) = true

"""
PG(n, space...)
PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold(), space = ()])
Expand Down Expand Up @@ -329,18 +327,18 @@ function DynamicPPL.assume(
unset_flag!(vi, vn, "del")
r = rand(rng, dist)
vi[vn] = vectorize(dist, r)
setgid!(vi, spl.selector, vn)
DynamicPPL.setgid!(vi, spl.selector, vn)
setorder!(vi, vn, get_num_produce(vi))
else
updategid!(vi, vn, spl)
DynamicPPL.updategid!(vi, vn, spl)
r = vi[vn]
end
else # vn belongs to other sampler <=> conditionning on vn
if haskey(vi, vn)
r = vi[vn]
else
r = rand(rng, dist)
push!(vi, vn, r, dist, Selector(:invalid))
push!(vi, vn, r, dist, DynamicPPL.Selector(:invalid))
end
lp = logpdf_with_trans(dist, r, istrans(vi, vn))
acclogp!(vi, lp)
Expand Down
14 changes: 6 additions & 8 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ module Inference

using ..Core
using ..Utilities
using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo,
using DynamicPPL: Metadata, VarInfo, TypedVarInfo,
islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
settrans!, _getvns, getdist, CACHERESET,
settrans!, _getvns, getdist,
Model, Sampler, SampleFromPrior, SampleFromUniform,
Selector, DefaultContext, PriorContext,
LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist,
DefaultContext, PriorContext,
LikelihoodContext, set_flag!, unset_flag!,
getspace, inspace
using Distributions, Libtask, Bijectors
using DistributionsAD: VectorOfMultivariate
Expand All @@ -25,8 +25,6 @@ import AdvancedMH; const AMH = AdvancedMH
import AdvancedPS
import BangBang
import ..Core: getchunksize, getADbackend
import DynamicPPL: get_matching_type,
VarName, _getranges, _getindex, getval, _getvns
import EllipticalSliceSampling
import Random
import MCMCChains
Expand Down Expand Up @@ -414,8 +412,8 @@ include("hmc.jl")
include("mh.jl")
include("is.jl")
include("AdvancedSMC.jl")
include("gibbs.jl")
include("gibbs_conditional.jl")
include("gibbs.jl")
include("../contrib/inference/sghmc.jl")
include("emcee.jl")

Expand All @@ -430,7 +428,7 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
@eval DynamicPPL.getspace(::$alg{<:Any, space}) where {space} = space
end

function get_matching_type(
function DynamicPPL.get_matching_type(
spl::Sampler{<:Union{PG, SMC}},
vi,
::Type{TV},
Expand Down
7 changes: 0 additions & 7 deletions src/inference/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ struct ESS{space} <: InferenceAlgorithm end
ESS() = ESS{()}()
ESS(space::Symbol) = ESS{(space,)}()

isgibbscomponent(::ESS) = true

# always accept in the first step
function DynamicPPL.initialstep(
rng::AbstractRNG,
Expand Down Expand Up @@ -58,11 +56,6 @@ function AbstractMCMC.step(
# obtain previous sample
f = vi[spl]

# recompute log-likelihood in logp
if spl.selector.tag !== :default
model(rng, vi, spl)
end

# define previous sampler state
# (do not use cache to avoid in-place sampling from prior)
oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)
Expand Down
98 changes: 87 additions & 11 deletions src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ Determine whether algorithm `alg` is allowed as a Gibbs component.
"""
isgibbscomponent(alg) = false

isgibbscomponent(::ESS) = true
isgibbscomponent(::GibbsConditional) = true
isgibbscomponent(::Hamiltonian) = true
isgibbscomponent(::MH) = true
isgibbscomponent(::PG) = true

"""
Gibbs(algs...)
Expand Down Expand Up @@ -83,11 +88,72 @@ metadata(t::GibbsTransition) = (lp = t.lp,)
DynamicPPL.getlogp(t::GibbsTransition) = t.lp

# extract varinfo object from state
getvarinfo(state) = state.vi
getvarinfo(state::AbstractVarInfo) = state
"""
gibbs_varinfo(model, sampler, state)
Return the variables corresponding to the current `state` of the Gibbs component `sampler`.
"""
gibbs_varinfo(model, sampler, state) = varinfo(state)
varinfo(state) = state.vi
varinfo(state::AbstractVarInfo) = state

"""
gibbs_state(model, sampler, state, varinfo)
Return an updated state, taking into account the variables sampled by other Gibbs components.
# Arguments
- `model`: model targeted by the Gibbs sampler.
- `sampler`: the sampler for this Gibbs component.
- `state`: the state of `sampler` computed in the previous iteration.
- `varinfo`: the variables, including the ones sampled by other Gibbs components.
"""
gibbs_state(model, sampler, state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo

# Update state in Gibbs sampling
function gibbs_state(
model::Model,
spl::Sampler{<:Hamiltonian},
state::HMCState,
varinfo::AbstractVarInfo,
)
# Update hamiltonian
θ_old = varinfo[spl]
hamiltonian = get_hamiltonian(model, spl, varinfo, state, length(θ_old))

# update state with new varinfo object
gibbs_update_state(state::AbstractVarInfo, varinfo::AbstractVarInfo) = varinfo
# TODO: Avoid mutation
resize!(state.z.θ, length(θ_old))
state.z.θ .= θ_old
z = state.z

return HMCState(varinfo, state.i, state.traj, hamiltonian, z, state.adaptor)
end

"""
gibbs_rerun(prev_alg, alg)
Check if the model should be rerun to recompute the log density before sampling with the
Gibbs component `alg` and after sampling from Gibbs component `prev_alg`.
By default, the function returns `true`.
"""
gibbs_rerun(prev_alg, alg) = true

# `vi.logp` already contains the log joint probability if the previous sampler
# used a `GibbsConditional` or one of the standard `Hamiltonian` algorithms
gibbs_rerun(::GibbsConditional, ::MH) = false
gibbs_rerun(::Hamiltonian, ::MH) = false

# `vi.logp` already contains the log joint probability if the previous sampler
# used a `GibbsConditional` or a `MH` algorithm
gibbs_rerun(::MH, ::Hamiltonian) = false
gibbs_rerun(::GibbsConditional, ::Hamiltonian) = false

# do not have to recompute `vi.logp` since it is not used in `step`
gibbs_rerun(prev_alg, ::GibbsConditional) = false

# Do not recompute `vi.logp` since it is reset anyway in `step`
gibbs_rerun(prev_alg, ::PG) = false

# Initialize the Gibbs sampler.
function DynamicPPL.initialstep(
Expand All @@ -107,9 +173,8 @@ function DynamicPPL.initialstep(
else
prev_alg = algs[i-1]
end
rerun = !isa(alg, MH) || prev_alg isa PG || prev_alg isa ESS ||
prev_alg isa GibbsConditional
selector = Selector(Symbol(typeof(alg)), rerun)
rerun = gibbs_rerun(prev_alg, alg)
selector = DynamicPPL.Selector(Symbol(typeof(alg)), rerun)
Sampler(alg, model, selector)
end

Expand All @@ -130,10 +195,16 @@ function DynamicPPL.initialstep(

# Compute initial states of the local samplers.
states = map(samplers) do local_spl
# Recompute `vi.logp` if needed.
if local_spl.selector.rerun
model(rng, vi, local_spl)
end

# Compute initial state.
_, state = DynamicPPL.initialstep(rng, model, local_spl, vi; kwargs...)

# update VarInfo object
vi = getvarinfo(state)
# Update `VarInfo` object.
vi = gibbs_varinfo(model, local_spl, state)

return state
end
Expand All @@ -157,14 +228,19 @@ function AbstractMCMC.step(
vi = state.vi
samplers = state.samplers
states = map(samplers, state.states) do _sampler, _state
# Recompute `vi.logp` if needed.
if _sampler.selector.rerun
model(rng, vi, _sampler)
end

# Update state of current sampler with updated `VarInfo` object.
current_state = gibbs_update_state(_state, vi)
current_state = gibbs_state(model, _sampler, _state, vi)

# Step through the local sampler.
_, newstate = AbstractMCMC.step(rng, model, _sampler, current_state; kwargs...)

# Update `VarInfo` object.
vi = getvarinfo(newstate)
vi = gibbs_varinfo(model, _sampler, newstate)

return newstate
end
Expand Down
9 changes: 2 additions & 7 deletions src/inference/gibbs_conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,14 @@ end

DynamicPPL.getspace(::GibbsConditional{S}) where {S} = (S,)

isgibbscomponent(::GibbsConditional) = true

function DynamicPPL.initialstep(
rng::AbstractRNG,
model::Model,
spl::Sampler{<:GibbsConditional},
vi::AbstractVarInfo;
kwargs...
)
return AbstractMCMC.step(rng, model, spl, vi; kwargs...)
return nothing, vi
end

function AbstractMCMC.step(
Expand All @@ -78,14 +76,11 @@ function AbstractMCMC.step(
vi::AbstractVarInfo;
kwargs...
)
if spl.selector.rerun # Recompute joint in logp
model(rng, vi)
end

condvals = conditioned(tonamedtuple(vi))
conddist = spl.alg.conditional(condvals)
updated = rand(rng, conddist)
vi[spl] = [updated;] # setindex allows only vectors in this case...
model(rng, vi, SampleFromPrior()) # update log joint probability

return nothing, vi
end
Expand Down
Loading

2 comments on commit 1a4b4e5

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/29327

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.15.9 -m "<description of version>" 1a4b4e5a75d295742097e52e9bdd8ed0b7e5fd65
git push origin v0.15.9

Please sign in to comment.