Skip to content

Commit

Permalink
Add NestedToMCMCAdapter to enable compatibility with ArviZ and MCMC w…
Browse files Browse the repository at this point in the history
…orkflows (arviz-devs#2391)
  • Loading branch information
SamuelSonoiki committed Dec 5, 2024
1 parent 529d795 commit fa7934f
Showing 1 changed file with 110 additions and 0 deletions.
110 changes: 110 additions & 0 deletions arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,96 @@

_log = logging.getLogger(__name__)

class NestedToMCMCAdapter:
"""
Adapter to convert a NestedSampler object into an MCMC-compatible interface.
This class reshapes posterior samples from a NestedSampler into a chain-and-draw
structure expected by MCMC workflows, providing compatibility with downstream
tools like ArviZ for posterior analysis.
Parameters
----------
nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
The NestedSampler object containing posterior samples.
rng_key : jax.random.PRNGKey
The random key used for sampling.
num_samples : int
The total number of posterior samples to draw.
num_chains : int, optional
The number of artificial chains to create for MCMC compatibility (default is 1).
*args : tuple
Additional positional arguments required by the model (e.g., data, labels).
**kwargs : dict
Additional keyword arguments required by the model.
Attributes
----------
samples : dict
Reshaped posterior samples organized by variable name.
thinning : int
Dummy thinning attribute for compatibility with MCMC.
sampler : NestedToMCMCAdapter
Mimics the sampler attribute of an MCMC object.
model : callable
The probabilistic model used in the NestedSampler.
_args : tuple
Positional arguments passed to the model.
_kwargs : dict
Keyword arguments passed to the model.
Methods
-------
get_samples(group_by_chain=True)
Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
get_extra_fields(group_by_chain=True)
Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
"""
def __init__(self, nested_sampler, rng_key, num_samples, *args, num_chains=1, **kwargs):
self.nested_sampler = nested_sampler
self.rng_key = rng_key
self.num_samples = num_samples
self.num_chains = num_chains
self.samples = self._reshape_samples()
self.thinning = 1
self.sampler = self
self.model = nested_sampler.model
self._args = args
self._kwargs = kwargs

def _reshape_samples(self):
raw_samples = self.nested_sampler.get_samples(self.rng_key, self.num_samples)
samples_per_chain = self.num_samples // self.num_chains
return {
k: np.reshape(v[:samples_per_chain * self.num_chains],
(self.num_chains, samples_per_chain, *v.shape[1:]))
for k, v in raw_samples.items()
}

def get_samples(self, group_by_chain=True):
if group_by_chain:
return self.samples
else:
# Flatten chains into a single dimension
return {k: v.reshape(-1, *v.shape[2:]) for k, v in self.samples.items()}

def get_extra_fields(self, group_by_chain=True):
# Generate dummy fields since NestedSampler does not produce these
n_chains = self.num_chains
n_samples = self.num_samples // self.num_chains

# Create dummy values for extra fields
extra_fields = {
"accept_prob": np.full((n_chains, n_samples), 1.0), # Assume all proposals are accepted
"step_size": np.full((n_chains, n_samples), 0.1), # Dummy step size
"num_steps": np.full((n_chains, n_samples), 10), # Dummy number of steps
}

if not group_by_chain:
# Flatten the chains into a single dimension
extra_fields = {k: v.reshape(-1, *v.shape[2:]) for k, v in extra_fields.items()}

return extra_fields

class NumPyroConverter:
"""Encapsulate NumPyro specific logic."""
Expand All @@ -37,6 +127,10 @@ def __init__(
dims=None,
pred_dims=None,
num_chains=1,
rng_key=None,
num_samples=1000,
data=None,
labels=None,
):
"""Convert NumPyro data into an InferenceData object.
Expand Down Expand Up @@ -68,6 +162,14 @@ def __init__(
import numpyro

self.posterior = posterior
self.rng_key = rng_key
self.num_samples = num_samples

if isinstance(posterior, numpyro.contrib.nested_sampling.NestedSampler):
posterior = NestedToMCMCAdapter(posterior, rng_key, num_samples,
num_chains=num_chains, data=data, labels=labels)
self.posterior = posterior

self.prior = jax.device_get(prior)
self.posterior_predictive = jax.device_get(posterior_predictive)
self.predictions = predictions
Expand Down Expand Up @@ -340,6 +442,10 @@ def from_numpyro(
dims=None,
pred_dims=None,
num_chains=1,
rng_key=None,
num_samples=1000,
data=None,
labels=None,
):
"""Convert NumPyro data into an InferenceData object.
Expand Down Expand Up @@ -383,4 +489,8 @@ def from_numpyro(
dims=dims,
pred_dims=pred_dims,
num_chains=num_chains,
rng_key=rng_key,
num_samples=num_samples,
data=data,
labels=labels,
).to_inference_data()

0 comments on commit fa7934f

Please sign in to comment.