From 8d8416ac6c7363c6003ee6ea1fbaac26b4fc8dc3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Sep 2023 19:22:52 +0100 Subject: [PATCH] Attach `varname_to_symbol` mapping to `Chains` (#2078) * _params_to_array now returns varnames and values instead of symbols and values * updated other uses of _params_to_array * Update Project.toml * make inclusion of varname_to_symbol mapping in chains optional * Update Project.toml * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 4 ++-- ext/TuringOptimExt.jl | 2 +- src/mcmc/Inference.jl | 23 ++++++++++++++--------- src/mcmc/emcee.jl | 10 +++++----- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index a89e72eec..c0231afc6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.29" +version = "0.29.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -56,7 +56,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.23.15" +DynamicPPL = "0.23.17" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index a0710893e..eb594929d 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -253,7 +253,7 @@ function _optimize( Turing.Inference.getparams(model, f.varinfo), DynamicPPL.getlogp(f.varinfo) )] - varnames, _ = Turing.Inference._params_to_array(model, ts) + varnames = map(Symbol, first(Turing.Inference._params_to_array(model, ts))) # Store the parameters and their names in an array. vmat = NamedArrays.NamedArray(vals, varnames) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 7f4fad950..1b09668ec 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -310,18 +310,17 @@ end function _params_to_array(model::DynamicPPL.Model, ts::Vector) - # TODO: Do we really need to use `Symbol` here? - names_set = OrderedSet{Symbol}() + names_set = OrderedSet{VarName}() # Extract the parameter names and values from each transition. dicts = map(ts) do t nms_and_vs = getparams(model, t) - nms = map(Symbol ∘ first, nms_and_vs) + nms = map(first, nms_and_vs) vs = map(last, nms_and_vs) for nm in nms push!(names_set, nm) end # Convert the names and values to a single dictionary. - return Dict(nms[j] => vs[j] for j in 1:length(vs)) + return OrderedDict(zip(nms, vs)) end names = collect(names_set) vals = [get(dicts[i], key, missing) for i in eachindex(dicts), @@ -379,29 +378,35 @@ function AbstractMCMC.bundle_samples( save_state = false, stats = missing, sort_chain = false, + include_varname_to_symbol = true, discard_initial = 0, thinning = 1, kwargs... ) # Convert transitions to array format. # Also retrieve the variable names. - nms, vals = _params_to_array(model, ts) + varnames, vals = _params_to_array(model, ts) + varnames_symbol = map(Symbol, varnames) # Get the values of the extra parameters in each transition. extra_params, extra_values = get_transition_extras(ts) # Extract names & construct param array. - nms = [nms; extra_params] + nms = [varnames_symbol; extra_params] parray = hcat(vals, extra_values) # Get the average or final log evidence, if it exists. le = getlogevidence(ts, spl, state) # Set up the info tuple. + info = NamedTuple() + + if include_varname_to_symbol + info = merge(info, (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),)) + end + if save_state - info = (model = model, sampler = spl, samplerstate = state) - else - info = NamedTuple() + info = merge(info, (model = model, sampler = spl, samplerstate = state)) end # Merge in the timing info, if available diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index f89cad955..d41596075 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -109,7 +109,8 @@ function AbstractMCMC.bundle_samples( params_vec = map(Base.Fix1(_params_to_array, model), samples) # Extract names and values separately. - nms = params_vec[1][1] + varnames = params_vec[1][1] + varnames_symbol = map(Symbol, varnames) vals_vec = [p[2] for p in params_vec] # Get the values of the extra parameters in each transition. @@ -120,7 +121,7 @@ function AbstractMCMC.bundle_samples( extra_values_vec = [e[2] for e in extra_vec] # Extract names & construct param array. - nms = [nms; extra_params] + nms = [varnames_symbol; extra_params] # `hcat` first to ensure we get the right `eltype`. x = hcat(first(vals_vec), first(extra_values_vec)) # Pre-allocate to minimize memory usage. @@ -133,10 +134,9 @@ function AbstractMCMC.bundle_samples( le = getlogevidence(samples, state, spl) # Set up the info tuple. + info = (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),) if save_state - info = (model = model, sampler = spl, samplerstate = state) - else - info = NamedTuple() + info = merge(info, (model = model, sampler = spl, samplerstate = state)) end # Concretize the array before giving it to MCMCChains.