Commit
Fixes for Pyro 1.9 compatibility (#518)
* Fixes for Pyro 1.9 compatibility

* format with updated black

* fix version

* unxfail cases

eb8680 committed Feb 16, 2024
1 parent 29ff860 commit 90675c0
Showing 10 changed files with 49 additions and 23 deletions.
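
Two compatibility techniques recur below: the handler modules branch on the return type of pyro.poutine.handlers._make_handler, and the test files gate their xfail marks on the installed Pyro version. The version gate is inlined at each use site; as a standalone sketch (the helper name pyro_version_tuple is hypothetical, not part of the diff):

import pyro


def pyro_version_tuple() -> tuple:
    # "1.9.0+ci123" -> (1, 9, 0): drop any local build suffix, keep major.minor.patch.
    return tuple(map(int, pyro.__version__.split("+")[0].split(".")[:3]))


# The tests below expect the torch.func/PyroParam failure only on Pyro <= 1.8.6.
XFAIL_OLD_PYRO = pyro_version_tuple() <= (1, 8, 6)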
3 changes: 2 additions & 1 deletion chirho/__init__.py

@@ -2,4 +2,5 @@
 Project short description.
 """
-__version__ = "0.0.1"
+
+__version__ = "0.2.0"
7 changes: 6 additions & 1 deletion chirho/interventional/handlers.py

@@ -118,4 +118,9 @@ def _pyro_post_sample(self, msg):
         )


-do = pyro.poutine.handlers._make_handler(Interventions)[1]
+if isinstance(pyro.poutine.handlers._make_handler(Interventions), tuple):
+    do = pyro.poutine.handlers._make_handler(Interventions)[1]
+else:
+
+    @pyro.poutine.handlers._make_handler(Interventions)
+    def do(fn: Callable, actions: Mapping[Hashable, AtomicIntervention[T]]): ...
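
As the branch above implies, Pyro releases through 1.8.x had _make_handler return a (messenger_class, handler_function) pair, so the handler was taken at index [1]; in Pyro 1.9 it is applied as a decorator, and the decorated signature declares the handler's arguments. Either way the result is the same do effect handler. A minimal usage sketch (the model and the intervention value are illustrative):

import pyro
import pyro.distributions as dist
import torch

from chirho.interventional.handlers import do


def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(x, 1.0))


# Replace the value sampled at site "x" with a constant during execution.
with do(actions={"x": torch.tensor(0.0)}):
    y = model()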
8 changes: 7 additions & 1 deletion chirho/observational/handlers/condition.py

@@ -110,4 +110,10 @@ def _pyro_sample(self, msg):
         self._current_site = None


-condition = pyro.poutine.handlers._make_handler(Observations)[1]
+if isinstance(pyro.poutine.handlers._make_handler(Observations), tuple):
+    # backwards compatibility
+    condition = pyro.poutine.handlers._make_handler(Observations)[1]
+else:
+
+    @pyro.poutine.handlers._make_handler(Observations)
+    def condition(fn: Callable, data: Mapping[str, Observation[T]]): ...
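
condition gets the same two-way branch, with an explicit backwards-compatibility comment. A minimal usage sketch (the model and the observed value are illustrative):

import pyro
import pyro.distributions as dist
import torch

from chirho.observational.handlers.condition import condition


def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(x, 1.0))


# Treat site "y" as observed at a fixed value.
with condition(data={"y": torch.tensor(1.0)}):
    model()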
1 change: 1 addition & 0 deletions chirho/robust/internals/nmc.py

@@ -124,6 +124,7 @@ class BatchedNMCLogMarginalLikelihood(Generic[P, T], torch.nn.Module):
        used to approximate marginal distribution, defaults to 1
    :type num_samples: int, optional
    """
+
    model: Callable[P, Any]
    guide: Optional[Callable[P, Any]]
    num_samples: int
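
Formatting-only change, presumably from the "format with updated black" pass: black 24.x inserts a blank line between a class docstring and the first attribute.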
3 changes: 1 addition & 2 deletions chirho/robust/ops.py

@@ -16,8 +16,7 @@
 class Functional(Protocol[P, S]):
     def __call__(
         self, __model: Callable[P, Any], *models: Callable[P, Any]
-    ) -> Callable[P, S]:
-        ...
+    ) -> Callable[P, S]: ...


 def influence_fn(
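
The collapsed ) -> Callable[P, S]: ... is black 24.x's compact formatting for ellipsis-only ("dummy") bodies, another effect of the "format with updated black" bullet.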
5 changes: 1 addition & 4 deletions tests/observational/test_cut_posterior_modules.py

@@ -212,10 +212,7 @@ def run_svi_inference(model, n_steps=1000, verbose=True, lr=0.03, **model_kwargs
 def analytical_linear_gaussian_cut_posterior(data):
     post_sd_mod_one = math.sqrt((1 + NUM_SAMPS_MODULE_ONE / SIGMA_ONE**2) ** (-1))
     pr_eta_cut = dist.Normal(
-        1
-        / SIGMA_ONE**2
-        * data["w"].sum()
-        / (1 + NUM_SAMPS_MODULE_ONE / SIGMA_ONE**2),
+        1 / SIGMA_ONE**2 * data["w"].sum() / (1 + NUM_SAMPS_MODULE_ONE / SIGMA_ONE**2),
         scale=post_sd_mod_one,
     )
     post_mean_mod_two = lambda eta: (  # noqa
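
Formatting-only: the updated black joins the arithmetic chain onto one line, which fits within the default 88-character limit.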
9 changes: 6 additions & 3 deletions tests/robust/test_handlers.py

@@ -28,11 +28,14 @@
     (SimpleModel, lambda _: SimpleGuide(), {"y"}, None),
     pytest.param(
         SimpleModel,
-        pyro.infer.autoguide.AutoNormal,
+        lambda m: pyro.infer.autoguide.AutoNormal(pyro.poutine.block(hide=["y"])(m)),
         {"y"},
         1,
-        marks=pytest.mark.xfail(
-            reason="torch.func autograd doesnt work with PyroParam"
+        marks=(
+            [pytest.mark.xfail(reason="torch.func autograd doesnt work with PyroParam")]
+            if tuple(map(int, pyro.__version__.split("+")[0].split(".")[:3]))
+            <= (1, 8, 6)
+            else []
         ),
     ),
 ]
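
Two changes here. First, the guide factory now blocks the observed site before building the autoguide, presumably so AutoNormal does not create parameters for "y". Second, pytest.param accepts a list of marks, and an empty list means no marks, so the xfail now applies only on Pyro <= 1.8.6; on 1.9 the case is expected to pass. This is the "unxfail cases" bullet from the commit message.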
8 changes: 5 additions & 3 deletions tests/robust/test_internals_compositions.py

@@ -49,9 +49,11 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition():
     )

     v = {
-        k: torch.ones_like(v).unsqueeze(0)
-        if k != "model.guide.loc_a"
-        else torch.zeros_like(v).unsqueeze(0)
+        k: (
+            torch.ones_like(v).unsqueeze(0)
+            if k != "model.guide.loc_a"
+            else torch.zeros_like(v).unsqueeze(0)
+        )
         for k, v in log_prob_params.items()
     }
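
Formatting-only: black 24.x parenthesizes a conditional expression that spans multiple lines, here inside the dict comprehension.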
19 changes: 14 additions & 5 deletions tests/robust/test_internals_linearize.py

@@ -69,11 +69,20 @@ def test_batch_cg_solve(ndim: int, dtype: torch.dtype, num_particles: int):
     (SimpleModel, lambda _: SimpleGuide(), {"y"}, None),
     pytest.param(
         SimpleModel,
-        pyro.infer.autoguide.AutoNormal,
+        lambda m: pyro.infer.autoguide.AutoNormal(
+            pyro.poutine.block(
+                hide=[
+                    "y",
+                ]
+            )(m)
+        ),
         {"y"},
         1,
-        marks=pytest.mark.xfail(
-            reason="torch.func autograd doesnt work with PyroParam"
+        marks=(
+            [pytest.mark.xfail(reason="torch.func autograd doesnt work with PyroParam")]
+            if tuple(map(int, pyro.__version__.split("+")[0].split(".")[:3]))
+            <= (1, 8, 6)
+            else []
         ),
     ),
 ]

@@ -117,7 +126,7 @@ def test_nmc_param_influence_smoke(
     for k, v in test_datum_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"
         assert not torch.isinf(v).any(), f"eif for {k} had infs"
-        if not k.endswith("guide.loc_a"):
+        if not (k.endswith("guide.loc_a") or k.endswith("a_unconstrained")):
             assert not torch.isclose(
                 v, torch.zeros_like(v)
             ).all(), f"eif for {k} was zero"

@@ -162,7 +171,7 @@ def test_nmc_param_influence_vmap_smoke(
     for k, v in test_data_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"
         assert not torch.isinf(v).any(), f"eif for {k} had infs"
-        if not k.endswith("guide.loc_a"):
+        if not (k.endswith("guide.loc_a") or k.endswith("a_unconstrained")):
             assert not torch.isclose(
                 v, torch.zeros_like(v)
             ).all(), f"eif for {k} was zero"
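
Beyond the same guide-factory and version-gated xfail changes as in test_handlers.py, the zero-influence allowlist now also accepts parameter names ending in a_unconstrained, presumably the unconstrained-space names the new AutoNormal guide creates for the same site.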
9 changes: 6 additions & 3 deletions tests/robust/test_ops.py

@@ -28,11 +28,14 @@
     (SimpleModel, lambda _: SimpleGuide(), {"y"}, None),
     pytest.param(
         SimpleModel,
-        pyro.infer.autoguide.AutoNormal,
+        lambda m: pyro.infer.autoguide.AutoNormal(pyro.poutine.block(hide=["y"])(m)),
         {"y"},
         1,
-        marks=pytest.mark.xfail(
-            reason="torch.func autograd doesnt work with PyroParam"
+        marks=(
+            [pytest.mark.xfail(reason="torch.func autograd doesnt work with PyroParam")]
+            if tuple(map(int, pyro.__version__.split("+")[0].split(".")[:3]))
+            <= (1, 8, 6)
+            else []
         ),
     ),
 ]
