diff --git a/tests/infer/test_valid_models.py b/tests/infer/test_valid_models.py index 86ab157a05..c67b1b2c74 100644 --- a/tests/infer/test_valid_models.py +++ b/tests/infer/test_valid_models.py @@ -2344,6 +2344,58 @@ def guide(): assert_ok(model, guide, elbo) +@pytest.mark.stage("funsor") +@pytest.mark.parametrize("num_particles", [1, 2]) +def test_collapse_diag_normal_plate_normal(num_particles): + T, d = 5, 3 + data = torch.randn((T, d)) + + def model(): + x = pyro.sample("x", dist.Normal(0., 1.)).unsqueeze(-1) + with poutine.collapse(): + with pyro.plate("data", T, dim=-1): + expand_shape = (d,) if num_particles == 1 else (num_particles, 1, d) + y = pyro.sample("y", dist.Normal(x, 1.).expand(expand_shape).to_event(1)) + pyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data) + + def guide(): + loc = pyro.param("loc", torch.tensor(0.)) + scale = pyro.param("scale", torch.tensor(1.), constraint=constraints.positive) + pyro.sample("x", dist.Normal(loc, scale)) + + elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True, + max_plate_nesting=1) + assert_ok(model, guide, elbo) + + +@pytest.mark.stage("funsor") +@pytest.mark.parametrize("num_particles", [1, 2]) +def test_collapse_normal_mvn_mvn(num_particles): + T, d, S = 5, 3, 4 + data = torch.randn((T, d, S)) + + def model(): + x = pyro.sample("x", dist.Normal(0., 1.)).unsqueeze(-1) + with poutine.collapse(): + with pyro.plate("d", d, dim=-1): + expand_shape = (d, S) if num_particles == 1 else (num_particles, d, S) + beta0 = pyro.sample("beta0", dist.Normal(x, 1.).expand(expand_shape).to_event(1)) + beta = pyro.sample("beta", dist.MultivariateNormal(beta0, torch.eye(S))) + + mean = torch.ones((T, d)) @ beta + with pyro.plate("data", T, dim=-1): + pyro.sample("obs", dist.MultivariateNormal(mean, torch.eye(S)), obs=data) + + def guide(): + loc = pyro.param("loc", torch.tensor(0.)) + scale = pyro.param("scale", torch.tensor(1.), constraint=constraints.positive) + pyro.sample("x", dist.Normal(loc, scale)) + + elbo = Trace_ELBO(num_particles=num_particles, vectorize_particles=True, + max_plate_nesting=1) + assert_ok(model, guide, elbo) + + @pytest.mark.stage("funsor") @pytest.mark.parametrize("num_particles", [1, 2]) def test_collapse_beta_bernoulli(num_particles):