Skip to content

Commit

Permalink
Fixing issue Samples are outside the support for DiscreteUniform dist…
Browse files Browse the repository at this point in the history
…ribution pyro-ppl#1834
  • Loading branch information
Deathn0t committed Jul 24, 2024
1 parent f6eb6ce commit 59de90f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
7 changes: 7 additions & 0 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,13 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
and site["fn"].has_enumerate_support
and not site["is_observed"]
}
self._support_enumerates = {
name: site["fn"].enumerate_support(False)
for name, site in self._prototype_trace.items()
if site["type"] == "sample"
and site["fn"].has_enumerate_support
and not site["is_observed"]
}
self._gibbs_sites = [
name
for name, site in self._prototype_trace.items()
Expand Down
6 changes: 6 additions & 0 deletions numpyro/infer/mixed_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jax import grad, jacfwd, lax, random
from jax.flatten_util import ravel_pytree
import jax
import jax.numpy as jnp

from numpyro.infer.hmc import momentum_generator
Expand Down Expand Up @@ -301,6 +302,11 @@ def body_fn(i, vals):
adapt_state=adapt_state,
)

z_discrete = jax.tree.map(
lambda idx, support: support[idx],
z_discrete,
self._support_enumerates,
)
z = {**z_discrete, **hmc_state.z}
return MixedHMCState(z, hmc_state, rng_key, accept_prob)

Expand Down

0 comments on commit 59de90f

Please sign in to comment.