Skip to content

Commit

Permalink
Fix DiscreteMarkovChain logp
Browse files — browse the repository at this point in the history
  • Branch information:
ricardoV94 committed Oct 27, 2023
1 parent e163a6f commit b08610c
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 330 deletions.
655 changes: 330 additions & 325 deletions notebooks/discrete_markov_chain.ipynb

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions pymc_experimental/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import logp
from pymc.pytensorf import intX
from pymc.pytensorf import constant_fold, intX
from pymc.util import check_dist_not_registered
from pytensor.graph.basic import Node
from pytensor.tensor import TensorVariable
Expand Down Expand Up @@ -252,11 +252,16 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)

# We cannot leave any RV in the logp graph, even if just for an assert
[init_dist_leading_dim] = constant_fold(
[pt.atleast_1d(init_dist).shape[0]], raise_not_constant=False
)

return check_parameters(
mc_logprob,
pt.all(pt.eq(P.shape[-(n_lags + 1) :], P.shape[-1])),
pt.all(pt.allclose(P.sum(axis=-1), 1.0)),
pt.eq(pt.atleast_1d(init_dist).shape[0], n_lags),
pt.eq(init_dist_leading_dim, n_lags),
msg="Last (n_lags + 1) dimensions of P must be square, "
"P must sum to 1 along the last axis, "
"First dimension of init_dist must be n_lags",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,19 @@ def test_logp_with_default_init_dist(self):
P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
x0 = pm.Categorical.dist(p=np.ones(3) / 3)

chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
value = np.array([0, 1, 2])
logp_expected = np.log((1 / 3) * 0.5 * 0.3)

logp = pm.logp(chain, [0, 1, 2]).eval()
assert logp == pytest.approx(np.log((1 / 3) * 0.5 * 0.3), rel=1e-6)
# Test dist directly
chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
logp_eval = pm.logp(chain, value).eval()
np.testing.assert_allclose(logp_eval, logp_expected, rtol=1e-6)

# Test via Model
with pm.Model() as m:
DiscreteMarkovChain("chain", P=P, init_dist=x0, steps=3)
model_logp_eval = m.compile_logp()({"chain": value})
np.testing.assert_allclose(model_logp_eval, logp_expected, rtol=1e-6)

def test_logp_with_user_defined_init_dist(self):
P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
Expand Down

0 comments on commit b08610c

Please sign in to comment.