Utility for unflattening Datasets #27

Open
sethaxen opened this issue Oct 31, 2022 · 0 comments
Labels
enhancement New feature or request


The natural way to represent a draw from a posterior distribution is as a NamedTuple whose keys are the parameter names and whose values are the corresponding parameter values, which can be scalars, arrays, or arbitrary Julia objects. All draws for a chain then form a vector of such NamedTuples, and multiple chains form a vector of those vectors. When we convert to InferenceData, we "flatten" these until we get numeric arrays. Each element of such an array is a marginal draw, which is useful for plotting and diagnostics.
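
To make the flattening concrete, here is a minimal sketch in plain Julia for a single chain. The draws and the variable names are made up for illustration, and this is not how InferenceObjects implements the conversion:

```julia
# Hypothetical draws from one chain: each draw is a NamedTuple.
draws = [(mu = 0.1 * i, theta = [1.0 * i, 2.0 * i]) for i in 1:3]

# "Flatten" each parameter into a numeric array, with draw as the last dimension.
mu = [d.mu for d in draws]                      # 3-element Vector{Float64}
theta = reduce(hcat, [d.theta for d in draws])  # 2×3 Matrix{Float64}
```

Each column of `theta` is one draw of the vector-valued parameter, and each element is a marginal draw.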

Sometimes, though, users need the unflattened draws. For example, when interacting with a PPL, one often needs draws in the format that PPL produces, which in general will not look like a Dataset. In #11 we discuss ideas for not flattening. A simpler alternative is to provide utility functions for "unflattening". Here's an example of such a function:

julia> using DimensionalData, InferenceObjects

julia> function unflatten(f, v, keep_dims=(:chain, :draw))
           dims = Dimensions.otherdims(v, keep_dims)
           isempty(dims) && return v
           keep_dims_actual = Dimensions.otherdims(v, dims)
           dimnums = Dimensions.dimnum(v, dims)
           data_new = dropdims(mapslices(Base.vect ∘ f, parent(v); dims=dimnums); dims=dimnums)
           return DimArray(data_new, keep_dims_actual)
       end;
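
The key step in the function above is the combination of `mapslices` and `dropdims`. As a rough illustration on a plain `Array`, with no DimensionalData involved and a made-up array `A`:

```julia
# Dims 1 and 2 play the role of :a and :b; dim 3 plays the role of the kept dims.
A = reshape(collect(1.0:24.0), 2, 3, 4)

# Wrapping the result for each slice in a 1-element vector makes mapslices
# store the whole slice as a single element, so the mapped dims have length 1.
B = mapslices(Base.vect ∘ identity, A; dims=(1, 2))  # 1×1×4 array of matrices

# Drop the singleton dims, leaving one matrix per remaining index.
C = dropdims(B; dims=(1, 2))                         # 4-element Vector{Matrix{Float64}}
```

Without the `Base.vect` wrapper, `mapslices` would concatenate the returned matrices back into a 2×3×4 array.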

By passing f=identity, we can handle the case where draws are scalars or arrays of scalars:

julia> x = convert_to_dataset((; x=randn(2, 3, 8, 4)); dims=(x=[:a, :b],)).x
2×3×8×4 DimArray{Float64,4} x with dimensions: 
  Dim{:a} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points,
  Dim{:b} Sampled{Int64} Base.OneTo(3) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1, 1]
     1         2         3
 1   0.844121  1.79069  -0.349435
 2  -0.435955  2.21937   0.102086
[and 31 more slices...]

julia> x_unflat = unflatten(identity, x)
8×4 DimArray{Matrix{Float64},2} with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
      4
 1      [-0.710667 2.17712 1.25004; 1.15662 0.138343 0.511868]
 2      [-1.59593 -0.847627 0.185637; 1.62011 0.733101 -0.82679]
 3      [1.6434 0.188635 0.434926; -2.70953 0.223494 -1.00055]
 4      [-0.753704 -2.25251 0.32903; 1.97774 -0.744595 1.0287]
 5     [0.837521 -0.252849 0.0989726; -1.10382 0.511166 0.566629]
 6      [-1.58429 -0.164573 1.83263; 0.875992 -0.174146 -1.10488]
 7      [-2.21422 -0.398891 -1.26135; 1.27395 -0.150042 0.243492]
 8      [0.789781 0.052268 -1.51552; 0.5554 1.08581 -1.16574]

julia> x_unflat[1]
2×3 Matrix{Float64}:
  0.844121  1.79069  -0.349435
 -0.435955  2.21937   0.102086

Other choices of f let us handle cases where draws are not array types. For example, here's how we might unflatten a real array representing complex draws:

julia> z = convert_to_dataset((; z=randn(2, 8, 4)); dims=(z=[:reim],), coords=(reim=[:re, :im],)).z
2×8×4 DimArray{Float64,3} z with dimensions: 
  Dim{:reim} Categorical{Symbol} Symbol[re, im] ReverseOrdered,
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
[:, :, 1]
        1           2          3          4          5          6          7         8
  :re   0.149895   -0.758902  -0.162169  -1.58568   -1.9113    -0.873895  -1.15336  -0.723117
  :im  -0.0615223  -0.191197   0.552402   0.754498  -0.139014   0.496133   1.69164  -1.05489
[and 3 more slices...]

julia> z_unflat = unflatten(Base.splat(complex), z)
8×4 DimArray{ComplexF64,2} with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
             1                      2                      3                       4
 1   0.149895-0.0615223im   -1.49337+0.0310133im  -0.186236-0.632437im     0.122908-1.8747im
 2  -0.758902-0.191197im   0.0847507+1.8477im      0.699646+0.0940246im   -0.700787+0.589689im
 3  -0.162169+0.552402im   -0.426661-0.215763im     1.24455-0.30482im       0.87671-0.0396714im
 4   -1.58568+0.754498im    -1.08887-0.0911398im   -1.18796+0.0439568im    0.583836+0.226613im
 5    -1.9113-0.139014im    -1.11748+0.521976im   -0.453853-0.668656im     -1.40155+0.216688im
 6  -0.873895+0.496133im    0.471934+0.508555im     -1.1003-0.844055im       2.6073-0.25573im
 7   -1.15336+1.69164im     0.107038+0.070659im    -2.15358-1.19693im    -0.0646238-0.749879im
 8  -0.723117-1.05489im       1.0455-0.601896im   -0.931837+0.621233im     0.789712+0.442579im
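
To see what `Base.splat(complex)` does to each slice in the example above, here is a small sketch on a plain matrix. The values in `Z` are made up, with real parts in the first row and imaginary parts in the second:

```julia
Z = [1.0 2.0 3.0; -1.0 0.5 0.0]

# to_complex([a, b]) is equivalent to complex(a, b).
to_complex = Base.splat(complex)

# Each column (a slice along dim 1) is collapsed to a single complex number.
zc = dropdims(mapslices(Base.vect ∘ to_complex, Z; dims=1); dims=1)
```

Any function that maps a slice to a single object could be used in place of `to_complex`, for example a constructor for a PPL-specific type.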

By applying this approach to all parameters in a Dataset, we can unflatten everything:

julia> using ArviZExampleData

julia> idata = load_example_data("centered_eight");

julia> post = idata.posterior
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
  Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)

with metadata Dict{String, Any} with 6 entries:
  "created_at"                => "2022-10-13T14:37:37.315398"
  "inference_library_version" => "4.2.2"
  "sampling_time"             => 7.48011
  "tuning_steps"              => 1000
  "arviz_version"             => "0.13.0.dev0"
  "inference_library"         => "pymc"

julia> post_new = Dataset(map(v -> unflatten(identity, v), NamedTuple(post)))
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
  Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points
and 3 layers:
  :mu    Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
  :theta Vector{Float64} dims: Dim{:draw}, Dim{:chain} (500×4)
  :tau   Float64 dims: Dim{:draw}, Dim{:chain} (500×4)


julia> post_new[1]
(mu = 7.871796366146925, theta = [12.320685578094814, 9.905366892588605, 14.9516154956564, 11.011484941973162, 5.5796015919074735, 16.901795293711004, 13.198059333176934, 15.06136583596694], tau = 4.725740062893666)
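
From unflattened layers it is a short step back to the vector-of-NamedTuples format described at the top. A minimal sketch using plain arrays indexed by (draw, chain); the layer values and the names `layers` and `chains` are hypothetical and not part of any existing API:

```julia
# Hypothetical unflattened layers: 3 draws, 2 chains.
mu = [10.0 * c + d for d in 1:3, c in 1:2]            # 3×2 Matrix{Float64}
theta = [fill(Float64(d), 2) for d in 1:3, c in 1:2]  # 3×2 Matrix{Vector{Float64}}
layers = (; mu, theta)

# For each chain, collect one NamedTuple per draw by indexing every layer.
chains = [[map(l -> l[d, c], layers) for d in axes(mu, 1)] for c in axes(mu, 2)]
```

Here `chains[c][d]` is the NamedTuple for draw `d` of chain `c`.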

I propose we add something like this utility to the API to make it easier to use InferenceObjects with PPLs.
