Skip to content

Commit

Permalink
variable naming / destructuring (#2465)
Browse files Browse the repository at this point in the history
* Variable naming, destructuring

* Tuple -> Vec
  • Loading branch information
penelopeysm authored Jan 14, 2025
1 parent c44d81a commit d93a0dd
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,49 +429,49 @@ recursively on the remaining samplers, until no samplers remain. Return the glob
and a tuple of initial states for all component samplers.
"""
function gibbs_initialstep_recursive(
rng, model, varnames, samplers, vi, states=(); initial_params=nothing, kwargs...
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
)
# End recursion
if isempty(varnames) && isempty(samplers)
if isempty(varname_vecs) && isempty(samplers)
return vi, states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers

# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames_local)[:]
DynamicPPL.subset(vi, varnames)[:]

Check warning on line 446 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L446

Added line #L446 was not covered by tests
end

# Construct the conditioned model.
model_local, context_local = make_conditional(model, varnames_local, vi)
conditioned_model, context = make_conditional(model, varnames, vi)

# Take initial step.
_, new_state_local = AbstractMCMC.step(
# Take initial step with the current sampler.
_, new_state = AbstractMCMC.step(
rng,
model_local,
sampler_local;
conditioned_model,
sampler;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context_local))
vi = merge(vi, get_global_varinfo(context))
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)

states = (states..., new_state_local)
states = (states..., new_state)
return gibbs_initialstep_recursive(
rng,
model,
varnames[2:end],
samplers[2:end],
varname_vecs_tail,
samplers_tail,
vi,
states;
initial_params=initial_params,
Expand Down Expand Up @@ -624,26 +624,26 @@ function on the tail, until there are no more samplers left.
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
varnames,
varname_vecs,
samplers,
states,
global_vi,
new_states=();
kwargs...,
)
# End recursion.
if isempty(varnames) && isempty(samplers) && isempty(states)
if isempty(varname_vecs) && isempty(samplers) && isempty(states)
return global_vi, new_states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
state_local = first(states)
varnames, varname_vecs_tail... = varname_vecs
sampler, samplers_tail... = samplers
state, states_tail... = states

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)
varinfo_local = match_linking!!(varinfo_local, state_local, model)
conditioned_model, context = make_conditional(model, varnames, global_vi)
vi = subset(global_vi, varnames)
vi = match_linking!!(vi, state, model)

# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
# sampled by other samplers, we don't need to `setparams`, but could rather simply
Expand All @@ -654,27 +654,27 @@ function gibbs_step_recursive(
# going to be a significant expense anyway.
# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = setparams_varinfo!!(
model_local, sampler_local, state_local, varinfo_local
state = setparams_varinfo!!(
conditioned_model, sampler, state, vi
)

# Take a step with the local sampler.
new_state_local = last(
AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
new_state = last(
AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...)
)

new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = merge(get_global_varinfo(context), new_vi_local)
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))

new_states = (new_states..., new_state_local)
new_states = (new_states..., new_state)
return gibbs_step_recursive(
rng,
model,
varnames[2:end],
samplers[2:end],
states[2:end],
varname_vecs_tail,
samplers_tail,
states_tail,
new_global_vi,
new_states;
kwargs...,
Expand Down

0 comments on commit d93a0dd

Please sign in to comment.