Skip to content

Commit

Permalink
Merge pull request #93 from gibsramen/fix-90
Browse files Browse the repository at this point in the history
Fix #90
  • Loading branch information
gibsramen authored Oct 4, 2023
2 parents 35cf431 + 09f8580 commit 3d7aac5
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
27 changes: 16 additions & 11 deletions birdman/inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import List, Sequence, Union

import arviz as az
Expand All @@ -21,18 +22,25 @@ def fit_to_inference(
if posterior_predictive is not None and posterior_predictive not in dims:
raise KeyError("Must include dimensions for posterior predictive!")

# Required because as of writing, CmdStanVB.stan_variable defaults to
# returning the mean rather than the sample
if isinstance(fit, CmdStanVB):
stan_var_fn = partial(fit.stan_variable, mean=False)
else:
stan_var_fn = fit.stan_variable

das = dict()

for param in params:
data = fit.stan_variable(param)
data = stan_var_fn(param)

_dims = dims[param]
_coords = {k: coords[k] for k in _dims}

das[param] = stan_var_to_da(data, _coords, _dims, chains, draws)

if log_likelihood:
data = fit.stan_variable(log_likelihood)
data = stan_var_fn(log_likelihood)

_dims = dims[log_likelihood]
_coords = {k: coords[k] for k in _dims}
Expand All @@ -43,7 +51,7 @@ def fit_to_inference(
ll_ds = None

if posterior_predictive:
data = fit.stan_variable(posterior_predictive)
data = stan_var_fn(posterior_predictive)

_dims = dims[posterior_predictive]
_coords = {k: coords[k] for k in _dims}
Expand Down Expand Up @@ -84,21 +92,19 @@ def concatenate_inferences(
"""
group_list = []
group_list.append([x.posterior for x in inf_list])
group_list.append([x.sample_stats for x in inf_list])
if "log_likelihood" in inf_list[0].groups():
group_list.append([x.log_likelihood for x in inf_list])
if "posterior_predictive" in inf_list[0].groups():
group_list.append([x.posterior_predictive for x in inf_list])

po_ds = xr.concat(group_list[0], concatenation_name)
ss_ds = xr.concat(group_list[1], concatenation_name)
group_dict = {"posterior": po_ds, "sample_stats": ss_ds}
group_dict = {"posterior": po_ds}

if "log_likelihood" in inf_list[0].groups():
ll_ds = xr.concat(group_list[2], concatenation_name)
ll_ds = xr.concat(group_list[1], concatenation_name)
group_dict["log_likelihood"] = ll_ds
if "posterior_predictive" in inf_list[0].groups():
pp_ds = xr.concat(group_list[3], concatenation_name)
pp_ds = xr.concat(group_list[2], concatenation_name)
group_dict["posterior_predictive"] = pp_ds

all_group_inferences = []
Expand All @@ -114,16 +120,15 @@ def concatenate_inferences(
return az.concat(*all_group_inferences)


# TODO: Fix docstring
def stan_var_to_da(
data: np.ndarray,
coords: dict,
dims: dict,
chains: int,
draws: int
):
"""Convert Stan variable draws to xr.DataArray.
"""
"""Convert Stan variable draws to xr.DataArray."""
data = np.stack(np.split(data, chains))

coords["draw"] = np.arange(draws)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pytest

from birdman import inference as mu
from birdman.default_models import NegativeBinomialSingle
from birdman import ModelIterator


class TestToInference:
Expand Down Expand Up @@ -78,3 +81,30 @@ def test_serial_ppll(self, example_model):
nb_data = example_model.fit.stan_variable(v)
nb_data = np.array(np.split(nb_data, 4, axis=0))
np.testing.assert_array_almost_equal(nb_data, inf_data)


@pytest.mark.parametrize("method", ["mcmc", "vi"])
def test_concat(table_biom, metadata, method):
tbl = table_biom
md = metadata

model_iterator = ModelIterator(
table=tbl,
model=NegativeBinomialSingle,
formula="host_common_name",
metadata=md,
)

infs = []
for fname, model in model_iterator:
model.compile_model()
model.fit_model(method, num_draws=100)
infs.append(model.to_inference())

inf_concat = mu.concatenate_inferences(
infs,
coords={"feature": tbl.ids("observation")},
)
exp_feat_ids = tbl.ids("observation")
feat_ids = inf_concat.posterior.coords["feature"].to_numpy()
assert (exp_feat_ids == feat_ids).all()

0 comments on commit 3d7aac5

Please sign in to comment.