Adapt to major PyMC changes
ricardoV94 committed Dec 10, 2024
1 parent 1b09f82 commit 9cfc521
Showing 7 changed files with 30 additions and 13 deletions.
conda-envs/environment-test.yml (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ dependencies:
   - xhistogram
   - statsmodels
   - pip:
-    - pymc>=5.17.0 # CI was failing to resolve
+    - pymc>=5.19.1 # CI was failing to resolve
     - blackjax
     - scikit-learn
     - better_optimize>=0.0.10
conda-envs/windows-environment-test.yml (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ dependencies:
   - xhistogram
   - statsmodels
   - pip:
-    - pymc>=5.17.0 # CI was failing to resolve
+    - pymc>=5.19.1 # CI was failing to resolve
     - blackjax
     - scikit-learn
     - better_optimize>=0.0.10
pymc_experimental/distributions/timeseries.py (18 changes: 15 additions & 3 deletions)
@@ -281,10 +281,20 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
 class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
     name = "discrete_markov_chain_gibbs_metropolis"
 
-    def __init__(self, vars, proposal="uniform", order="random", model=None):
+    def __init__(
+        self,
+        vars,
+        proposal="uniform",
+        order="random",
+        model=None,
+        initial_point=None,
+        compile_kwargs: dict | None = None,
+        **kwargs,
+    ):
         model = pm.modelcontext(model)
         vars = get_value_vars_from_user_vars(vars, model)
-        initial_point = model.initial_point()
+        if initial_point is None:
+            initial_point = model.initial_point()
 
         dimcats = []
         # The above variable is a list of pairs (aggregate dimension, number
@@ -332,7 +342,9 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
         self.tune = True
 
         # We bypass CategoricalGibbsMetropolis's __init__ to avoid its specialized initialization logic
-        ArrayStep.__init__(self, vars, [model.compile_logp()])
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        ArrayStep.__init__(self, vars, [model.compile_logp(**compile_kwargs)], **kwargs)
 
     @staticmethod
     def competence(var):
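For context, a minimal sketch of how the extended signature could be used, assuming `DiscreteMarkovChain` and the step method are importable from this module; the `compile_kwargs` value and the sampling arguments are illustrative, not from this commit:

    import pymc as pm

    from pymc_experimental.distributions.timeseries import (
        DiscreteMarkovChain,
        DiscreteMarkovChainGibbsMetropolis,
    )

    with pm.Model() as model:
        init_dist = pm.Categorical.dist(p=[0.5, 0.5])
        chain = DiscreteMarkovChain(
            "chain", P=[[0.9, 0.1], [0.2, 0.8]], init_dist=init_dist, shape=(50,)
        )

        # New keywords: reuse a precomputed initial point and forward
        # compile_kwargs to model.compile_logp (values here are hypothetical).
        step = DiscreteMarkovChainGibbsMetropolis(
            [chain],
            initial_point=model.initial_point(),
            compile_kwargs={"mode": "FAST_COMPILE"},
        )
        idata = pm.sample(step=step, tune=0, chains=1, draws=100)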
pymc_experimental/inference/laplace.py (4 changes: 2 additions & 2 deletions)
@@ -391,14 +391,14 @@ def sample_laplace_posterior(

     else:
         info = mu.point_map_info
-        flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
+        flat_shapes = [size for _, _, size, _ in info]
         slices = [
             slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
         ]
 
         posterior_draws = [
             posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
-            for idx, (name, shape, dtype) in zip(slices, info)
+            for idx, (name, shape, _, dtype) in zip(slices, info)
         ]
 
         idata = laplace_draws_to_inferencedata(posterior_draws, model)
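The `point_map_info` change this adapts to appears to add an explicit flat size to each entry, so the flat length no longer needs to be recomputed from the shape. A self-contained sketch with hypothetical entries:

    import numpy as np

    # Hypothetical entries in the new (name, shape, size, dtype) layout;
    # previously each entry was (name, shape, dtype).
    info = [
        ("mu", (3,), 3, "float64"),
        ("sigma", (), 1, "float64"),
        ("L", (2, 2), 4, "float64"),
    ]

    flat_shapes = [size for _, _, size, _ in info]
    slices = [
        slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1]))
        for i in range(len(flat_shapes))
    ]
    assert slices == [slice(0, 3), slice(3, 4), slice(4, 8)]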
pymc_experimental/inference/smc/sampling.py (7 changes: 5 additions & 2 deletions)
@@ -206,13 +206,16 @@ def arviz_from_particles(model, particles):
     -------
     """
     n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
-    by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)}
+    by_varname = {
+        k.name: v.squeeze()[np.newaxis, :].astype(k.dtype)
+        for k, v in zip(model.value_vars, particles)
+    }
     varnames = [v.name for v in model.value_vars]
     with model:
         strace = NDArray(name=model.name)
         strace.setup(n_particles, 0)
         for particle_index in range(0, n_particles):
-            strace.record(point={k: by_varname[k][0][particle_index] for k in varnames})
+            strace.record(point={k: np.asarray(by_varname[k][0][particle_index]) for k in varnames})
     multitrace = MultiTrace((strace,))
     return to_inference_data(multitrace, log_likelihood=False)

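The two casts guard against the sampler handing back arrays whose dtype differs from the model's value variables (for example, JAX defaulting to float32 for a discrete variable). A small illustration of the `.astype(k.dtype)` step, with hypothetical values:

    import numpy as np

    particle = np.asarray([0.0, 1.0, 1.0], dtype="float32")  # hypothetical draws
    value_var_dtype = "int64"  # e.g. the dtype of a Categorical value variable

    # Mirrors v.squeeze()[np.newaxis, :].astype(k.dtype) from the hunk above.
    point_value = particle.squeeze()[np.newaxis, :].astype(value_var_dtype)
    assert point_value.shape == (1, 3) and point_value.dtype == np.dtype("int64")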
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -1,2 +1,2 @@
-pymc>=5.17.0
+pymc>=5.19.1
 scikit-learn
tests/distributions/test_discrete_markov_chain.py (8 changes: 5 additions & 3 deletions)
@@ -225,16 +225,18 @@ def test_change_size_univariate(self):
     def test_mcmc_sampling(self):
         with pm.Model(coords={"step": range(100)}) as model:
             init_dist = Categorical.dist(p=[0.5, 0.5])
-            DiscreteMarkovChain(
+            markov_chain = DiscreteMarkovChain(
                 "markov_chain",
                 P=[[0.1, 0.9], [0.1, 0.9]],
                 init_dist=init_dist,
                 shape=(100,),
                 dims="step",
             )
 
-            step_method = assign_step_methods(model)
-            assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)
+            _, assigned_step_methods = assign_step_methods(model)
+            assert assigned_step_methods[DiscreteMarkovChainGibbsMetropolis] == [
+                model.rvs_to_values[markov_chain]
+            ]
 
             # Sampler needs no tuning
             idata = pm.sample(
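The test now unpacks a pair from `assign_step_methods`, whose second element maps step-method classes to the value variables they were assigned; that contract is inferred from this change rather than documented here. A sketch under that assumption:

    import pymc as pm
    from pymc.sampling.mcmc import assign_step_methods

    with pm.Model() as model:
        x = pm.Bernoulli("x", p=0.5)

        _, assigned = assign_step_methods(model)
        # Keys are step-method classes, values lists of assigned value
        # variables, e.g. BinaryGibbsMetropolis -> [x] for this model.
        for step_class, value_vars in assigned.items():
            print(step_class.__name__, value_vars)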
