diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py index 01a52fa3b..3bd4e6614 100644 --- a/chirho/observational/handlers/condition.py +++ b/chirho/observational/handlers/condition.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union +from typing import Callable, Generic, Mapping, TypeVar, Union import pyro import torch from chirho.observational.internals import ObserveNameMessenger -from chirho.observational.ops import AtomicObservation, observe +from chirho.observational.ops import Observation, observe T = TypeVar("T") R = Union[float, torch.Tensor] @@ -64,7 +64,9 @@ class Observations(Generic[T], ObserveNameMessenger): a richer set of observational data types and enables counterfactual inference. """ - def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]): + data: Mapping[str, Observation[T]] + + def __init__(self, data: Mapping[str, Observation[T]]): self.data = data super().__init__() diff --git a/chirho/robust/__init__.py b/chirho/robust/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chirho/robust/handlers/__init__.py b/chirho/robust/handlers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py new file mode 100644 index 000000000..eb6e8d6ee --- /dev/null +++ b/chirho/robust/handlers/estimators.py @@ -0,0 +1,55 @@ +from typing import Any, Callable, TypeVar + +import torch +from typing_extensions import ParamSpec + +from chirho.robust.ops import Functional, Point, influence_fn + +P = ParamSpec("P") +S = TypeVar("S") +T = TypeVar("T") + + +def one_step_corrected_estimator( + functional: Functional[P, S], + *test_points: Point[T], + **influence_kwargs, +) -> Functional[P, S]: + """ + Returns a functional that computes the one-step correction for the + functional at a specified set of test points as discussed in [1]. + + :param functional: model summary functional of interest + :param test_points: points at which to compute the one-step correction + :return: functional to compute the one-step correction + + **References** + + [1] `Semiparametric doubly robust targeted double machine learning: a review`, + Edward H. Kennedy, 2022. 
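A usage sketch for ``one_step_corrected_estimator``, mirroring ``tests/robust/test_handlers.py`` later in this diff. ``SimpleModel`` and ``SimpleGuide`` are trimmed copies of the fixtures and docstring examples added by this PR, and the sample counts, ``max_plate_nesting``, and the choice of ``PredictiveFunctional`` (defined in ``chirho/robust/handlers/predictive.py`` below) as the target functional are illustrative choices rather than requirements:

.. code-block:: python

    import functools

    import pyro
    import pyro.distributions as dist
    import torch

    from chirho.robust.handlers.estimators import one_step_corrected_estimator
    from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel

    pyro.settings.set(module_local_params=True)


    class SimpleModel(pyro.nn.PyroModule):
        def forward(self):
            a = pyro.sample("a", dist.Normal(0, 1))
            with pyro.plate("data", 3, dim=-1):
                b = pyro.sample("b", dist.Normal(a, 1))
                return pyro.sample("y", dist.Normal(b, 1))


    class SimpleGuide(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.loc_a = torch.nn.Parameter(torch.rand(()))
            self.loc_b = torch.nn.Parameter(torch.rand((3,)))

        def forward(self):
            a = pyro.sample("a", dist.Normal(self.loc_a, 1))
            with pyro.plate("data", 3, dim=-1):
                b = pyro.sample("b", dist.Normal(self.loc_b, 1))
                return {"a": a, "b": b}


    model, guide = SimpleModel(), SimpleGuide()
    model(), guide()  # initialize lazy parameters

    # a single held-out test point drawn from the prior predictive
    with torch.no_grad():
        test_datum = {
            k: v[0]
            for k, v in pyro.infer.Predictive(
                model, num_samples=2, return_sites=["y"], parallel=True
            )().items()
        }

    # plug-in estimate of the predictive functional plus the influence-function correction
    estimator = one_step_corrected_estimator(
        functools.partial(PredictiveFunctional, num_samples=5),
        test_datum,
        max_plate_nesting=1,
        num_samples_outer=100,
        num_samples_inner=100,
    )(PredictiveModel(model, guide))

    corrected_estimate = estimator()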
+ """ + influence_kwargs_one_step = influence_kwargs.copy() + influence_kwargs_one_step["pointwise_influence"] = False + eif_fn = influence_fn(functional, *test_points, **influence_kwargs_one_step) + + def _corrected_functional(*model: Callable[P, Any]) -> Callable[P, S]: + plug_in_estimator = functional(*model) + correction_estimator = eif_fn(*model) + + def _estimator(*args, **kwargs) -> S: + plug_in_estimate = plug_in_estimator(*args, **kwargs) + correction = correction_estimator(*args, **kwargs) + + flat_plug_in_estimate, treespec = torch.utils._pytree.tree_flatten( + plug_in_estimate + ) + flat_correction, _ = torch.utils._pytree.tree_flatten(correction) + + return torch.utils._pytree.tree_unflatten( + [a + b for a, b in zip(flat_plug_in_estimate, flat_correction)], + treespec, + ) + + return _estimator + + return _corrected_functional diff --git a/chirho/robust/handlers/predictive.py b/chirho/robust/handlers/predictive.py new file mode 100644 index 000000000..f73c38bfd --- /dev/null +++ b/chirho/robust/handlers/predictive.py @@ -0,0 +1,140 @@ +from typing import Any, Callable, Generic, Optional, TypeVar + +import pyro +import torch +from typing_extensions import ParamSpec + +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.robust.internals.nmc import BatchedLatents +from chirho.robust.internals.utils import bind_leftmost_dim +from chirho.robust.ops import Point + +P = ParamSpec("P") +S = TypeVar("S") +T = TypeVar("T") + + +class PredictiveModel(Generic[P, T], torch.nn.Module): + """ + Given a Pyro model and guide, constructs a new model that behaves as if + the latent ``sample`` sites in the original model (i.e. the prior) + were replaced by their counterparts in the guide (i.e. the posterior). + + .. note:: Sites that only appear in the model are annotated in traces + produced by the predictive model with ``infer={"_model_predictive_site": True}`` . + + :param model: Pyro model. + :param guide: Pyro guide. + """ + + model: Callable[P, T] + guide: Optional[Callable[P, Any]] + + def __init__( + self, + model: Callable[P, T], + guide: Optional[Callable[P, Any]] = None, + ): + super().__init__() + self.model = model + self.guide = guide + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: + """ + Returns a sample from the posterior predictive distribution. + + :return: Sample from the posterior predictive distribution. + :rtype: T + """ + with pyro.poutine.infer_config( + config_fn=lambda msg: {"_model_predictive_site": False} + ): + with pyro.poutine.trace() as guide_tr: + if self.guide is not None: + self.guide(*args, **kwargs) + + block_guide_sample_sites = pyro.poutine.block( + hide=[ + name + for name, node in guide_tr.trace.nodes.items() + if node["type"] == "sample" + ] + ) + + with pyro.poutine.infer_config( + config_fn=lambda msg: {"_model_predictive_site": True} + ): + with block_guide_sample_sites: + with pyro.poutine.replay(trace=guide_tr.trace): + return self.model(*args, **kwargs) + + +class PredictiveFunctional(Generic[P, T], torch.nn.Module): + """ + Functional that returns a batch of samples from the predictive + distribution of a Pyro model. As with ``pyro.infer.Predictive`` , + the returned values are batched along their leftmost positional dimension. 
+ + Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)`` + when :class:`~chirho.robust.handlers.predictive.PredictiveModel` is used to construct + the ``model`` argument and infer the ``sample`` sites whose values should be returned, + and uses :class:`~BatchedLatents` to parallelize over samples from the model. + + .. warning:: ``PredictiveFunctional`` currently applies its own internal instance of + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` , + so it may not behave as expected if used within another enclosing + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` context. + + :param model: Pyro model. + :param num_samples: Number of samples to return. + """ + + model: Callable[P, Any] + num_samples: int + + def __init__( + self, + model: torch.nn.Module, + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + name: str = "__particles_predictive", + ): + super().__init__() + self.model = model + self.num_samples = num_samples + self._first_available_dim = ( + -max_plate_nesting - 1 if max_plate_nesting is not None else None + ) + self._mc_plate_name = name + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: + """ + Returns a batch of samples from the posterior predictive distribution. + + :return: Dictionary of samples from the posterior predictive distribution. + :rtype: Point[T] + """ + with IndexPlatesMessenger(first_available_dim=self._first_available_dim): + with pyro.poutine.trace() as model_tr: + with BatchedLatents(self.num_samples, name=self._mc_plate_name): + with pyro.poutine.infer_config( + config_fn=lambda msg: { + "_model_predictive_site": msg["infer"].get( + "_model_predictive_site", True + ) + } + ): + self.model(*args, **kwargs) + + return { + name: bind_leftmost_dim( + node["value"], + self._mc_plate_name, + event_dim=len(node["fn"].event_shape), + ) + for name, node in model_tr.trace.nodes.items() + if node["type"] == "sample" + and not pyro.poutine.util.site_is_subsample(node) + and node["infer"].get("_model_predictive_site", False) + } diff --git a/chirho/robust/internals/__init__.py b/chirho/robust/internals/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py new file mode 100644 index 000000000..29447c736 --- /dev/null +++ b/chirho/robust/internals/linearize.py @@ -0,0 +1,387 @@ +import functools +from typing import Any, Callable, Optional, TypeVar + +import pyro +import torch +from typing_extensions import Concatenate, ParamSpec + +from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood +from chirho.robust.internals.utils import ( + ParamDict, + make_flatten_unflatten, + make_functional_call, + reset_rng_state, +) +from chirho.robust.ops import Point + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +def _flat_conjugate_gradient_solve( + f_Ax: Callable[[torch.Tensor], torch.Tensor], + b: torch.Tensor, + *, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-3, +) -> torch.Tensor: + """ + Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312. + + :param f_Ax: a function to compute matrix vector products over a batch + of vectors ``x``. + :type f_Ax: Callable[[torch.Tensor], torch.Tensor] + :param b: batch of right hand sides of the equation to solve. 
+ :type b: torch.Tensor + :param cg_iters: number of conjugate iterations to run, defaults to None + :type cg_iters: Optional[int], optional + :param residual_tol: tolerance for convergence, defaults to 1e-3 + :type residual_tol: float, optional + :return: batch of solutions ``x*`` for equation Ax = b. + :rtype: torch.Tensor + + .. note:: + + Code is adapted from + https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py # noqa: E501 + + """ + assert len(b.shape), "b must be a 2D matrix" + + if cg_iters is None: + cg_iters = b.shape[1] + else: + cg_iters = min(cg_iters, b.shape[1]) + + def _batched_dot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return (x1 * x2).sum(axis=-1) # type: ignore + + def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + return a.unsqueeze(0).t() * B + + p = b.clone() + r = b.clone() + x = torch.zeros_like(b) + z = f_Ax(p) + rdotr = _batched_dot(r, r) + v = rdotr / _batched_dot(p, z) + newrdotr = rdotr + mu = newrdotr / rdotr + zeros_xr = torch.zeros_like(x) + for _ in range(cg_iters): + not_converged = rdotr > residual_tol + not_converged_broadcasted = not_converged.unsqueeze(0).t() + z = torch.where(not_converged_broadcasted, f_Ax(p), z) + v = torch.where(not_converged, rdotr / _batched_dot(p, z), v) + x += torch.where(not_converged_broadcasted, _batched_product(v, p), zeros_xr) + r -= torch.where(not_converged_broadcasted, _batched_product(v, z), zeros_xr) + newrdotr = torch.where(not_converged, _batched_dot(r, r), newrdotr) + mu = torch.where(not_converged, newrdotr / rdotr, mu) + p = torch.where(not_converged_broadcasted, r + _batched_product(mu, p), p) + rdotr = torch.where(not_converged, newrdotr, rdotr) + if torch.all(~not_converged): + return x + return x + + +def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T: + """ + Use Conjugate Gradient iteration to solve Ax = b. + + :param f_Ax: a function to compute matrix vector products over a batch + of vectors ``x``. + :type f_Ax: Callable[[T], T] + :param b: batch of right hand sides of the equation to solve. + :type b: T + :return: batch of solutions ``x*`` for equation Ax = b. + :rtype: T + """ + flatten, unflatten = make_flatten_unflatten(b) + + def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: + v_unflattened: T = unflatten(v) + result_unflattened = f_Ax(v_unflattened) + return flatten(result_unflattened) + + return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs)) + + +def make_empirical_fisher_vp( + batched_func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], + log_prob_params: ParamDict, + data: Point[T], + *args: P.args, + **kwargs: P.kwargs, +) -> Callable[[ParamDict], ParamDict]: + r""" + Returns a function that computes the empirical Fisher vector product for an arbitrary + vector :math:`v` using only Hessian vector products via a batched version of + Perlmutter's trick [1]. + + .. math:: + + -\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) v, + + where :math:`\phi` corresponds to ``log_prob_params``, :math:`\tilde{p}_{\phi}` denotes the + predictive distribution ``log_prob``, and :math:`x_n` are the data points in ``data``. 
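A quick numerical check of the batched conjugate gradient solver above, adapted from ``tests/robust/test_internals_linearize.py`` further down in this diff (matrix size and batch size are arbitrary):

.. code-block:: python

    import functools

    import torch

    from chirho.robust.internals.linearize import conjugate_gradient_solve

    ndim, num_particles = 5, 3
    U = torch.rand(ndim, ndim)
    A = torch.eye(ndim) + 0.1 * U.mm(U.t())  # symmetric positive definite
    expected_x = torch.randn(num_particles, ndim)
    b = torch.einsum("ij,nj->ni", A, expected_x)  # batch of right-hand sides

    solve = functools.partial(
        conjugate_gradient_solve,
        lambda v: torch.einsum("ij,nj->ni", A, v),  # batched matrix-vector product
        residual_tol=1e-10,
    )
    actual_x = solve(b)
    assert torch.all(torch.sum((actual_x - expected_x) ** 2, dim=1) < 1e-4)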
+ + :param func_log_prob: computes the log probability of ``data`` given ``log_prob_params`` + :type func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor] + :param log_prob_params: parameters of the predictive distribution + :type log_prob_params: ParamDict + :param data: data points + :type data: Point[T] + :param is_batched: if ``False``, ``func_log_prob`` is batched over ``data`` + using ``torch.func.vmap``. Otherwise, assumes ``func_log_prob`` is already batched + over multiple data points. ``Defaults to False``. + :type is_batched: bool, optional + :return: a function that computes the empirical Fisher vector product for an arbitrary + vector :math:`v` + :rtype: Callable[[ParamDict], ParamDict] + + **Example usage**: + + .. code-block:: python + + import pyro + import pyro.distributions as dist + import torch + + from chirho.robust.internals.linearize import make_empirical_fisher_vp + + pyro.settings.set(module_local_params=True) + + + class GaussianModel(pyro.nn.PyroModule): + def __init__(self, cov_mat: torch.Tensor): + super().__init__() + self.register_buffer("cov_mat", cov_mat) + + def forward(self, loc): + pyro.sample( + "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat) + ) + + + def gaussian_log_prob(params, data_point, cov_mat): + with pyro.validation_enabled(False): + return dist.MultivariateNormal( + loc=params["loc"], covariance_matrix=cov_mat + ).log_prob(data_point["x"]) + + + v = torch.tensor([1.0, 0.0], requires_grad=False) + loc = torch.ones(2, requires_grad=True) + cov_mat = torch.ones(2, 2) + torch.eye(2) + + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = pyro.infer.Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + + empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"] + + # Closed form solution for the Fisher vector product + # See "Multivariate normal distribution" in https://en.wikipedia.org/wiki/Fisher_information + prec_matrix = torch.linalg.inv(cov_mat) + true_vp = prec_matrix.mv(v) + + assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1)) + + + **References** + + [1] `Fast Exact Multiplication by the Hessian`, + Barak A. Pearlmutter, 1999. + """ + N = data[next(iter(data))].shape[0] # type: ignore + mean_vector = 1 / N * torch.ones(N) + + def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: + return batched_func_log_prob(params, data, *args, **kwargs) + + def _empirical_fisher_vp(v: ParamDict) -> ParamDict: + def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: + return torch.func.jvp( + bound_batched_func_log_prob, (log_prob_params,), (v,) + )[1] + + # Perlmutter's trick + vjp_fn = torch.func.vjp(jvp_fn, log_prob_params)[1] + return vjp_fn(-1 * mean_vector)[0] # Fisher = -E[Hessian] + + return _empirical_fisher_vp + + +def linearize( + *models: Callable[P, Any], + num_samples_outer: int, + num_samples_inner: Optional[int] = None, + max_plate_nesting: Optional[int] = None, + cg_iters: Optional[int] = None, + residual_tol: float = 1e-4, + pointwise_influence: bool = True, +) -> Callable[Concatenate[Point[T], P], ParamDict]: + r""" + Returns the influence function associated with the parameters + of a normalized probabilistic program ``model``. This function + computes the following quantity at an arbitrary point :math:`x^{\prime}`: + + .. 
math:: + + \left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right] + \nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad + \tilde{p}_{\phi}(x) = \int p_{\phi}(x, \theta) d\theta, + + where :math:`\phi` corresponds to ``log_prob_params``, + :math:`p(x, \theta)` denotes the ``model``, + :math:`\tilde{p}_{\phi}` denotes the predictive distribution ``log_prob`` induced + from the ``model``, and :math:`\{x_n\}_{n=1}^N` are the + data points drawn iid from the predictive distribution. + + :param model: Python callable containing Pyro primitives. + :type model: Callable[P, Any] + :param num_samples_outer: number of Monte Carlo samples to + approximate Fisher information in :func:`make_empirical_fisher_vp` + :type num_samples_outer: int + :param num_samples_inner: number of Monte Carlo samples used in + :class:`BatchedNMCLogPredictiveLikelihood`. Defaults to ``num_samples_outer**2``. + :type num_samples_inner: Optional[int], optional + :param max_plate_nesting: bound on max number of nested :func:`pyro.plate` + contexts. Defaults to ``None``. + :type max_plate_nesting: Optional[int], optional + :param cg_iters: number of conjugate gradient steps used to + invert Fisher information matrix, defaults to None + :type cg_iters: Optional[int], optional + :param residual_tol: tolerance used to terminate conjugate gradients + early, defaults to 1e-4 + :type residual_tol: float, optional + :param pointwise_influence: if ``True``, computes the influence function at each + point in ``points``. If ``False``, computes the efficient influence averaged + over ``points``. Defaults to True. + :type pointwise_influence: bool, optional + :return: the influence function associated with the parameters + :rtype: Callable[Concatenate[Point[T], P], ParamDict] + + **Example usage**: + + .. code-block:: python + + import pyro + import pyro.distributions as dist + import torch + + from chirho.robust.handlers.predictive import PredictiveModel + from chirho.robust.internals.linearize import linearize + + pyro.settings.set(module_local_params=True) + + + class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(b, 1)) + + + class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) + return {"a": a, "b": b} + + model = SimpleModel() + guide = SimpleGuide() + predictive = pyro.infer.Predictive( + model, guide=guide, num_samples=10, return_sites=["y"] + ) + points = predictive() + influence = linearize( + PredictiveModel(model, guide), + num_samples_outer=1000, + num_samples_inner=1000, + ) + + influence(points) + + .. note:: + + * Since the efficient influence function is approximated using Monte Carlo, the result + of this function is stochastic, i.e., evaluating this function on the same ``points`` + can result in different values. To reduce variance, increase ``num_samples_outer`` and + ``num_samples_inner`` in ``linearize_kwargs``. + + * Currently, ``model`` cannot contain any ``pyro.param`` statements. + This issue will be addressed in a future release: + https://github.com/BasisResearch/chirho/issues/393. 
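Continuing the docstring example above: the averaged (rather than pointwise) influence that ``one_step_corrected_estimator`` relies on can be requested explicitly by passing ``pointwise_influence=False``; ``model``, ``guide``, and ``points`` here are the objects constructed in that example:

.. code-block:: python

    # model, guide, and points as constructed in the docstring example above
    averaged_influence = linearize(
        PredictiveModel(model, guide),
        num_samples_outer=1000,
        num_samples_inner=1000,
        pointwise_influence=False,
    )

    # one entry per parameter of PredictiveModel(model, guide),
    # averaged over the rows of ``points``
    averaged_influence(points)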
+ """ + if len(models) > 1: + raise NotImplementedError("Only unary version of linearize is implemented.") + else: + (model,) = models + + assert isinstance(model, torch.nn.Module) + if num_samples_inner is None: + num_samples_inner = num_samples_outer**2 + + predictive = pyro.infer.Predictive( + model, + num_samples=num_samples_outer, + parallel=True, + ) + + batched_log_prob = BatchedNMCLogMarginalLikelihood( + model, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting + ) + log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob) + log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values()) + if cg_iters is None: + cg_iters = log_prob_params_numel + else: + cg_iters = min(cg_iters, log_prob_params_numel) + cg_solver = functools.partial( + conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol + ) + + def _fn( + points: Point[T], + *args: P.args, + **kwargs: P.kwargs, + ) -> ParamDict: + with torch.no_grad(): + data: Point[T] = predictive(*args, **kwargs) + data = {k: data[k] for k in points.keys()} + fvp = make_empirical_fisher_vp( + batched_func_log_prob, log_prob_params, data, *args, **kwargs + ) + pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp) + pinned_fvp_batched = torch.func.vmap( + lambda v: pinned_fvp(v), randomness="different" + ) + + def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor: + return batched_func_log_prob(p, points, *args, **kwargs) + + if pointwise_influence: + score_fn = torch.func.jacrev(bound_batched_func_log_prob) + point_scores = score_fn(log_prob_params) + else: + score_fn = torch.func.vjp(bound_batched_func_log_prob, log_prob_params)[1] + N_pts = points[next(iter(points))].shape[0] # type: ignore + point_scores = score_fn(1 / N_pts * torch.ones(N_pts))[0] + point_scores = {k: v.unsqueeze(0) for k, v in point_scores.items()} + return cg_solver(pinned_fvp_batched, point_scores) + + return _fn diff --git a/chirho/robust/internals/nmc.py b/chirho/robust/internals/nmc.py new file mode 100644 index 000000000..342abdcc0 --- /dev/null +++ b/chirho/robust/internals/nmc.py @@ -0,0 +1,213 @@ +import collections +import math +import typing +from typing import Any, Callable, Generic, Optional, TypeVar + +import pyro +import torch +from typing_extensions import ParamSpec + +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.indexed.ops import get_index_plates, indices_of +from chirho.observational.handlers.condition import Observations +from chirho.robust.internals.utils import ( + bind_leftmost_dim, + get_importance_traces, + site_is_delta, + unbind_leftmost_dim, +) +from chirho.robust.ops import Point + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +class BatchedLatents(pyro.poutine.messenger.Messenger): + """ + Effect handler that adds a fresh batch dimension to all latent ``sample`` sites. + Similar to wrapping a Pyro model in a ``pyro.plate`` context, but uses the machinery + in ``chirho.indexed`` to automatically allocate and track the fresh batch dimension + based on the ``name`` argument to ``BatchedLatents`` . + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param int num_particles: Number of particles to use for parallelization. + :param str name: Name of the fresh batch dimension. 
+ """ + + num_particles: int + name: str + + def __init__(self, num_particles: int, *, name: str = "__particles_mc"): + assert num_particles > 0 + assert len(name) > 0 + self.num_particles = num_particles + self.name = name + super().__init__() + + def _pyro_sample(self, msg: dict) -> None: + if ( + self.num_particles > 1 + and msg["value"] is None + and not pyro.poutine.util.site_is_factor(msg) + and not pyro.poutine.util.site_is_subsample(msg) + and not site_is_delta(msg) + and self.name not in indices_of(msg["fn"]) + ): + msg["fn"] = unbind_leftmost_dim( + msg["fn"].expand((1,) + msg["fn"].batch_shape), + self.name, + size=self.num_particles, + ) + + +class BatchedObservations(Generic[T], Observations[T]): + """ + Effect handler that takes a dictionary of observation values for ``sample`` sites + that are assumed to be batched along their leftmost dimension, adds a fresh named + dimension using the machinery in ``chirho.indexed``, and reshapes the observation + values so that the new ``chirho.observational.observe`` sites are batched along + the fresh named dimension. + + Useful in combination with ``pyro.infer.Predictive`` which returns a dictionary + of values whose leftmost dimension is a batch dimension over independent samples. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param Point[T] data: Dictionary of observation values. + :param str name: Name of the fresh batch dimension. + """ + + name: str + + def __init__(self, data: Point[T], *, name: str = "__particles_data"): + assert len(name) > 0 + self.name = name + super().__init__(data) + + def _pyro_observe(self, msg: dict) -> None: + super()._pyro_observe(msg) + if msg["kwargs"]["name"] in self.data: + rv, obs = msg["args"] + event_dim = ( + len(rv.event_shape) + if hasattr(rv, "event_shape") + else msg["kwargs"].get("event_dim", 0) + ) + batch_obs = unbind_leftmost_dim(obs, self.name, event_dim=event_dim) + msg["args"] = (rv, batch_obs) + + +class BatchedNMCLogMarginalLikelihood(Generic[P, T], torch.nn.Module): + r""" + Approximates the log marginal likelihood induced by ``model`` and ``guide`` + using importance sampling at an arbitrary batch of :math:`N` + points :math:`\{x_n\}_{n=1}^N`. + + .. math:: + \log \left(\frac{1}{M} \sum_{m=1}^M \frac{p(x_n \mid \theta_m) p(\theta_m) )}{q_{\phi}(\theta_m)} \right), + \quad \theta_m \sim q_{\phi}(\theta), + + where :math:`q_{\phi}(\theta)` is the guide, and :math:`p(x_n \mid \theta_m) p(\theta_m)` + is the model joint density of the data and the latents sampled from the guide. + + :param model: Python callable containing Pyro primitives. + :type model: torch.nn.Module + :param guide: Python callable containing Pyro primitives. + Must only contain continuous latent variables. 
+ :type guide: torch.nn.Module + :param num_samples: Number of Monte Carlo draws :math:`M` + used to approximate marginal distribution, defaults to 1 + :type num_samples: int, optional + """ + model: Callable[P, Any] + guide: Optional[Callable[P, Any]] + num_samples: int + + def __init__( + self, + model: torch.nn.Module, + guide: Optional[torch.nn.Module] = None, + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + data_plate_name: str = "__particles_data", + mc_plate_name: str = "__particles_mc", + ): + super().__init__() + self.model = model + self.guide = guide + self.num_samples = num_samples + self._first_available_dim = ( + -max_plate_nesting - 1 if max_plate_nesting is not None else None + ) + self._data_plate_name = data_plate_name + self._mc_plate_name = mc_plate_name + + def forward( + self, data: Point[T], *args: P.args, **kwargs: P.kwargs + ) -> torch.Tensor: + """ + Computes the log predictive likelihood of ``data`` given ``model`` and ``guide``. + + :param data: Dictionary of observations. + :type data: Point[T] + :return: Log marginal likelihood at each datapoint. + :rtype: torch.Tensor + """ + get_nmc_traces = get_importance_traces(self.model, self.guide) + + with IndexPlatesMessenger(first_available_dim=self._first_available_dim): + with BatchedLatents(self.num_samples, name=self._mc_plate_name): + with BatchedObservations(data, name=self._data_plate_name): + model_trace, guide_trace = get_nmc_traces(*args, **kwargs) + index_plates = get_index_plates() + + plate_name_to_dim = collections.OrderedDict( + (p, index_plates[p]) + for p in [self._mc_plate_name, self._data_plate_name] + if p in index_plates + ) + plate_frames = set(plate_name_to_dim.values()) + + log_weights = typing.cast(torch.Tensor, 0.0) + for site in model_trace.nodes.values(): + if site["type"] != "sample": + continue + site_log_prob = site["log_prob"] + for f in site["cond_indep_stack"]: + if f.dim is not None and f not in plate_frames: + site_log_prob = site_log_prob.sum(f.dim, keepdim=True) + log_weights = log_weights + site_log_prob + + for site in guide_trace.nodes.values(): + if site["type"] != "sample": + continue + site_log_prob = site["log_prob"] + for f in site["cond_indep_stack"]: + if f.dim is not None and f not in plate_frames: + site_log_prob = site_log_prob.sum(f.dim, keepdim=True) + log_weights = log_weights - site_log_prob + + # sum out particle dimension and discard + if self._mc_plate_name in index_plates: + log_weights = torch.logsumexp( + log_weights, + dim=plate_name_to_dim[self._mc_plate_name].dim, + keepdim=True, + ) - math.log(self.num_samples) + plate_name_to_dim.pop(self._mc_plate_name) + + # move data plate dimension to the left + for name in reversed(plate_name_to_dim.keys()): + log_weights = bind_leftmost_dim(log_weights, name) + + # pack log_weights by squeezing out rightmost dimensions + for _ in range(len(log_weights.shape) - len(plate_name_to_dim)): + log_weights = log_weights.squeeze(-1) + + return log_weights diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py new file mode 100644 index 000000000..9289027a7 --- /dev/null +++ b/chirho/robust/internals/utils.py @@ -0,0 +1,272 @@ +import contextlib +import functools +import math +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, TypeVar + +import pyro +import torch +from typing_extensions import Concatenate, ParamSpec + +from chirho.indexed.handlers import add_indices +from chirho.indexed.ops import IndexSet, get_index_plates, indices_of + +P = ParamSpec("P") 
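A sketch of calling the marginal likelihood estimator defined above directly. The toy model is an illustrative assumption, and leaving ``guide`` unset relies on the prior being used as the default proposal (per ``get_importance_traces`` in ``chirho/robust/internals/utils.py`` below); internally this composes ``BatchedLatents`` over the ``num_samples`` Monte Carlo draws with ``BatchedObservations`` over the rows of ``data``:

.. code-block:: python

    import pyro
    import pyro.distributions as dist
    import torch

    from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood


    def model():
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
        return pyro.sample("y", dist.Normal(x, 1.0))


    log_marginal = BatchedNMCLogMarginalLikelihood(
        model, num_samples=100, max_plate_nesting=1
    )

    # two independent observations of "y", batched along the leftmost dimension
    data = {"y": torch.tensor([0.5, -0.3])}
    log_p = log_marginal(data)  # per-datum NMC estimates of log p(y)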
+Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + +ParamDict = Mapping[str, torch.Tensor] + + +@functools.singledispatch +def make_flatten_unflatten( + v, +) -> Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]]: + """ + Returns functions to flatten and unflatten an object. Used as a helper + in :func:`chirho.robust.internals.linearize.conjugate_gradient_solve` + + :param v: some object + :raises NotImplementedError: + :return: flatten and unflatten functions + :rtype: Tuple[Callable[[T], torch.Tensor], Callable[[torch.Tensor], T]] + """ + raise NotImplementedError + + +@make_flatten_unflatten.register(torch.Tensor) +def _make_flatten_unflatten_tensor(v: torch.Tensor): + """ + Returns functions to flatten and unflatten a `torch.Tensor`. + """ + batch_size = v.shape[0] + + def flatten(v: torch.Tensor) -> torch.Tensor: + r""" + Flatten a tensor into a single vector. + """ + return v.reshape((batch_size, -1)) + + def unflatten(x: torch.Tensor) -> torch.Tensor: + r""" + Unflatten a vector into a tensor. + """ + return x.reshape(v.shape) + + return flatten, unflatten + + +@make_flatten_unflatten.register(dict) +def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]): + """ + Returns functions to flatten and unflatten a dictionary of `torch.Tensor`s. + """ + batch_size = next(iter(d.values())).shape[0] + + def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor: + r""" + Flatten a dictionary of tensors into a single vector. + """ + return torch.hstack([v.reshape((batch_size, -1)) for k, v in d.items()]) + + def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]: + r""" + Unflatten a vector into a dictionary of tensors. + """ + return dict( + zip( + d.keys(), + [ + v_flat.reshape(v.shape) + for v, v_flat in zip( + d.values(), + torch.split( + x, + [int(v.numel() / batch_size) for k, v in d.items()], + dim=1, + ), + ) + ], + ) + ) + + return flatten, unflatten + + +def make_functional_call( + mod: Callable[P, T] +) -> Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]]: + """ + Converts a PyTorch module into a functional call for use with + functions in :class:`torch.func`. + + :param mod: PyTorch module + :type mod: Callable[P, T] + :return: parameter dictionary and functional call + :rtype: Tuple[ParamDict, Callable[Concatenate[ParamDict, P], T]] + """ + assert isinstance(mod, torch.nn.Module) + param_dict: ParamDict = dict(mod.named_parameters()) + + @torch.func.functionalize + def mod_func(params: ParamDict, *args: P.args, **kwargs: P.kwargs) -> T: + with pyro.validation_enabled(False): + return torch.func.functional_call(mod, params, args, dict(**kwargs)) + + return param_dict, mod_func + + +@pyro.poutine.block() +@pyro.validation_enabled(False) +@torch.no_grad() +def guess_max_plate_nesting( + model: Callable[P, Any], guide: Callable[P, Any], *args: P.args, **kwargs: P.kwargs +) -> int: + """ + Guesses the maximum plate nesting level by running `pyro.infer.Trace_ELBO` + + :param model: Python callable containing Pyro primitives. + :type model: Callable[P, Any] + :param guide: Python callable containing Pyro primitives. + :type guide: Callable[P, Any] + :return: maximum plate nesting level + :rtype: int + """ + elbo = pyro.infer.Trace_ELBO() + elbo._guess_max_plate_nesting(model, guide, args, kwargs) + return elbo.max_plate_nesting + + +@contextlib.contextmanager +def reset_rng_state(rng_state: T): + """ + Helper to temporarily reset the Pyro RNG state. 
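A round-trip sketch for ``make_flatten_unflatten`` on a dictionary of batched tensors (the parameter names and shapes are arbitrary):

.. code-block:: python

    import torch

    from chirho.robust.internals.utils import make_flatten_unflatten

    params = {
        "loc": torch.randn(4, 3),  # leading dimension is the batch dimension
        "scale": torch.randn(4, 2, 2),
    }
    flatten, unflatten = make_flatten_unflatten(params)

    flat = flatten(params)
    assert flat.shape == (4, 3 + 2 * 2)

    roundtrip = unflatten(flat)
    assert all(torch.equal(roundtrip[k], params[k]) for k in params)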
+ """ + try: + prev_rng_state: T = pyro.util.get_rng_state() + yield pyro.util.set_rng_state(rng_state) + finally: + pyro.util.set_rng_state(prev_rng_state) + + +@functools.singledispatch +def unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs): + """ + Helper function to move the leftmost dimension of a ``torch.Tensor`` + or ``pyro.distributions.Distribution`` or other batched value + into a fresh named dimension using the machinery in ``chirho.indexed`` , + allocating a new dimension with the given name if necessary + via an enclosing :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param v: Batched value. + :param name: Name of the fresh dimension. + :param size: Size of the fresh dimension. If 1, the size is inferred from ``v`` . + """ + raise NotImplementedError + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_tensor( + v: torch.Tensor, name: str, size: int = 1, *, event_dim: int = 0 +) -> torch.Tensor: + size = max(size, v.shape[0]) + v = v.expand((size,) + v.shape[1:]) + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.shape + while new_dim - event_dim < -len(v.shape): + v = v[None] + if v.shape[0] == 1 and orig_shape[0] != 1: + v = torch.transpose(v, -len(orig_shape), new_dim - event_dim) + return v + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_distribution( + v: pyro.distributions.Distribution, name: str, size: int = 1, **kwargs +) -> pyro.distributions.Distribution: + size = max(size, v.batch_shape[0]) + if v.batch_shape[0] != 1: + raise NotImplementedError("Cannot freely reshape distribution") + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.batch_shape + + new_shape = (size,) + (1,) * (-new_dim - len(orig_shape)) + orig_shape[1:] + return v.expand(new_shape) + + +@functools.singledispatch +def bind_leftmost_dim(v, name: str, **kwargs): + """ + Helper function to move a named dimension managed by ``chirho.indexed`` + into a new unnamed dimension to the left of all named dimensions in the value. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + """ + raise NotImplementedError + + +@bind_leftmost_dim.register +def _bind_leftmost_dim_tensor( + v: torch.Tensor, name: str, *, event_dim: int = 0, **kwargs +) -> torch.Tensor: + if name not in indices_of(v, event_dim=event_dim): + return v + return torch.transpose( + v[None], -len(v.shape) - 1, get_index_plates()[name].dim - event_dim + ) + + +def get_importance_traces( + model: Callable[P, Any], + guide: Optional[Callable[P, Any]] = None, +) -> Callable[P, Tuple[pyro.poutine.Trace, pyro.poutine.Trace]]: + """ + Thin functional wrapper around :func:`~pyro.infer.enum.get_importance_trace` + that cleans up the original interface to avoid unnecessary arguments + and efficiently supports using the prior in a model as a default guide. + + :param model: Model to run. + :param guide: Guide to run. If ``None``, use the prior in ``model`` as a guide. + :returns: A function that takes the same arguments as ``model`` and ``guide`` and returns + a tuple of importance traces ``(model_trace, guide_trace)``. 
+ """ + + def _fn( + *args: P.args, **kwargs: P.kwargs + ) -> Tuple[pyro.poutine.Trace, pyro.poutine.Trace]: + if guide is not None: + model_trace, guide_trace = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, guide, args, kwargs + ) + return model_trace, guide_trace + else: # use prior as default guide, but don't run model twice + model_trace, _ = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, lambda *_, **__: None, args, kwargs + ) + + guide_trace = model_trace.copy() + for name, node in list(guide_trace.nodes.items()): + if node["type"] != "sample": + del model_trace.nodes[name] + elif pyro.poutine.util.site_is_factor(node) or node["is_observed"]: + del guide_trace.nodes[name] + return model_trace, guide_trace + + return _fn + + +def site_is_delta(msg: dict) -> bool: + d = msg["fn"] + while hasattr(d, "base_dist"): + d = d.base_dist + return isinstance(d, pyro.distributions.Delta) diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py new file mode 100644 index 000000000..86ed8f89d --- /dev/null +++ b/chirho/robust/ops.py @@ -0,0 +1,155 @@ +from typing import Any, Callable, Mapping, Protocol, TypeVar + +import torch +from typing_extensions import ParamSpec + +from chirho.observational.ops import Observation + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S", covariant=True) +T = TypeVar("T") + +Point = Mapping[str, Observation[T]] + + +class Functional(Protocol[P, S]): + def __call__( + self, __model: Callable[P, Any], *models: Callable[P, Any] + ) -> Callable[P, S]: + ... + + +def influence_fn( + functional: Functional[P, S], *points: Point[T], **linearize_kwargs +) -> Functional[P, S]: + """ + Returns a new functional that computes the efficient influence function for ``functional`` + at the given ``points`` with respect to the parameters of its probabilistic program arguments. + + :param functional: model summary of interest, which is a function of ``model`` + :param points: points for each input to ``functional`` at which to compute the efficient influence function + :return: functional that computes the efficient influence function for ``functional`` at ``points`` + + **Example usage**: + + .. 
code-block:: python + + import pyro + import pyro.distributions as dist + import torch + + from chirho.robust.handlers.predictive import PredictiveModel + from chirho.robust.ops import influence_fn + + pyro.settings.set(module_local_params=True) + + + class SimpleModel(pyro.nn.PyroModule): + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(b, 1)) + + + class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) + return {"a": a, "b": b} + + + class SimpleFunctional(torch.nn.Module): + def __init__(self, model, guide, num_monte_carlo=1000): + super().__init__() + self.model = model + self.guide = guide + self.num_monte_carlo = num_monte_carlo + + def forward(self): + with pyro.plate("monte_carlo_functional", size=self.num_monte_carlo, dim=-2): + posterior_guide_samples = pyro.poutine.trace(self.guide).get_trace() + model_at_theta = pyro.poutine.replay(trace=posterior_guide_samples)( + self.model + ) + model_samples = pyro.poutine.trace(model_at_theta).get_trace() + return model_samples.nodes["b"]["value"].mean(axis=0) + + + model = SimpleModel() + guide = SimpleGuide() + predictive = pyro.infer.Predictive( + model, guide=guide, num_samples=10, return_sites=["y"] + ) + points = predictive() + influence = influence_fn( + SimpleFunctional, + points, + num_samples_outer=1000, + num_samples_inner=1000, + )(PredictiveModel(model, guide)) + + influence() + + .. note:: + + * ``functional`` must compose with ``torch.func.jvp`` + * Since the efficient influence function is approximated using Monte Carlo, the result + of this function is stochastic, i.e., evaluating this function on the same ``points`` + can result in different values. To reduce variance, increase ``num_samples_outer`` and + ``num_samples_inner`` in ``linearize_kwargs``. + * Currently, ``model`` cannot contain any ``pyro.param`` statements. + This issue will be addressed in a future release: + https://github.com/BasisResearch/chirho/issues/393. + """ + from chirho.robust.internals.linearize import linearize + from chirho.robust.internals.utils import make_functional_call + + if len(points) != 1: + raise NotImplementedError( + "influence_fn currently only supports unary functionals" + ) + + def _influence_functional(*models: Callable[P, Any]) -> Callable[P, S]: + """ + Functional representing the efficient influence function of ``functional`` at ``points`` . + + :param models: Python callables containing Pyro primitives. + :return: efficient influence function for ``functional`` evaluated at ``model`` and ``points`` + """ + if len(models) != len(points): + raise ValueError("mismatch between number of models and points") + + linearized = linearize(*models, **linearize_kwargs) + target = functional(*models) + + # TODO check that target_params == model_params + assert isinstance(target, torch.nn.Module) + target_params, func_target = make_functional_call(target) + + def _fn(*args: P.args, **kwargs: P.kwargs) -> S: + """ + Evaluates the efficient influence function for ``functional`` at each + point in ``points``. 
+ + :return: efficient influence function evaluated at each point in ``points`` or averaged + """ + param_eif = linearized(*points, *args, **kwargs) + return torch.vmap( + lambda d: torch.func.jvp( + lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) + )[1], + in_dims=0, + randomness="different", + )(param_eif) + + return _fn + + return _influence_functional diff --git a/docs/source/index.rst b/docs/source/index.rst index dd14293bd..5ad12cdf5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,6 +42,7 @@ Table of Contents observational indexed dynamical + robust explainable .. toctree:: diff --git a/docs/source/robust.rst b/docs/source/robust.rst new file mode 100644 index 000000000..172cfaf28 --- /dev/null +++ b/docs/source/robust.rst @@ -0,0 +1,47 @@ +Robust +====== + +.. automodule:: chirho.robust + :members: + :undoc-members: + +Operations +---------- + +.. automodule:: chirho.robust.ops + :members: + :undoc-members: + +Handlers +-------- + +.. automodule:: chirho.robust.handlers + :members: + :undoc-members: + +.. automodule:: chirho.robust.handlers.estimators + :members: + :undoc-members: + +.. automodule:: chirho.robust.handlers.predictive + :members: + :undoc-members: + +Internals +--------- + +.. automodule:: chirho.robust.internals + :members: + :undoc-members: + +.. automodule:: chirho.robust.internals.linearize + :members: + :undoc-members: + +.. automodule:: chirho.robust.internals.nmc + :members: + :undoc-members: + +.. automodule:: chirho.robust.internals.utils + :members: + :undoc-members: diff --git a/setup.py b/setup.py index 827384c3a..47c6dcd3d 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ "pytorch-lightning", "scikit-image", "tensorboard", + "typing_extensions", ] DYNAMICAL_REQUIRE = ["torchdiffeq"] diff --git a/tests/robust/__init__.py b/tests/robust/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/robust/robust_fixtures.py b/tests/robust/robust_fixtures.py new file mode 100644 index 000000000..4496e7da6 --- /dev/null +++ b/tests/robust/robust_fixtures.py @@ -0,0 +1,230 @@ +import math +from typing import Callable, Optional, Tuple, TypedDict, TypeVar + +import pyro +import pyro.distributions as dist +import torch +from pyro.nn import PyroModule + +from chirho.observational.handlers import condition +from chirho.robust.internals.utils import ParamDict +from chirho.robust.ops import Point + +pyro.settings.set(module_local_params=True) +T = TypeVar("T") + + +class SimpleModel(PyroModule): + def __init__( + self, + link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0), + ): + super().__init__() + self.link_fn = link_fn + + def forward(self): + a = pyro.sample("a", dist.Normal(0, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(a, 1)) + return pyro.sample("y", dist.Normal(b, 1)) + + +class SimpleGuide(torch.nn.Module): + def __init__(self): + super().__init__() + self.loc_a = torch.nn.Parameter(torch.rand(())) + self.loc_b = torch.nn.Parameter(torch.rand((3,))) + + def forward(self): + a = pyro.sample("a", dist.Normal(self.loc_a, 1)) + with pyro.plate("data", 3, dim=-1): + b = pyro.sample("b", dist.Normal(self.loc_b, 1)) + return {"a": a, "b": b} + + +class GaussianModel(PyroModule): + def __init__(self, cov_mat: torch.Tensor): + super().__init__() + self.register_buffer("cov_mat", cov_mat) + + def forward(self, loc): + pyro.sample( + "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat) + ) + + +# Note: `gaussian_log_prob` is separate 
from the GaussianModel above because of upstream obstacles +# in the interaction between `pyro.nn.PyroModule` and `torch.func`. +# See https://github.com/BasisResearch/chirho/issues/393 +def gaussian_log_prob(params: ParamDict, data_point: Point[T], cov_mat) -> T: + with pyro.validation_enabled(False): + return dist.MultivariateNormal( + loc=params["loc"], covariance_matrix=cov_mat + ).log_prob(data_point["x"]) + + +class DataConditionedModel(PyroModule): + r""" + Helper class for conditioning on data. + """ + + def __init__(self, model: PyroModule): + super().__init__() + self.model = model + + def forward(self, D: Point[torch.Tensor]): + with condition(data=D): + # Assume first dimension corresponds to # of datapoints + N = D[next(iter(D))].shape[0] + return self.model.forward(N=N) + + +class HighDimLinearModel(pyro.nn.PyroModule): + def __init__( + self, + p: int, + link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0), + prior_scale: Optional[float] = None, + ): + super().__init__() + self.p = p + self.link_fn = link_fn + if prior_scale is None: + self.prior_scale = 1 / math.sqrt(self.p) + else: + self.prior_scale = prior_scale + + def sample_outcome_weights(self): + return pyro.sample( + "outcome_weights", + dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), + ) + + def sample_intercept(self): + return pyro.sample("intercept", dist.Normal(0.0, 1.0)) + + def sample_propensity_weights(self): + return pyro.sample( + "propensity_weights", + dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1), + ) + + def sample_treatment_weight(self): + return pyro.sample("treatment_weight", dist.Normal(0.0, 1.0)) + + def sample_covariate_loc_scale(self): + loc = pyro.sample( + "covariate_loc", dist.Normal(0.0, 1.0).expand((self.p,)).to_event(1) + ) + scale = pyro.sample( + "covariate_scale", dist.LogNormal(0, 1).expand((self.p,)).to_event(1) + ) + return loc, scale + + def forward(self, N: int = 1): + intercept = self.sample_intercept() + outcome_weights = self.sample_outcome_weights() + propensity_weights = self.sample_propensity_weights() + tau = self.sample_treatment_weight() + x_loc, x_scale = self.sample_covariate_loc_scale() + with pyro.plate("obs", N, dim=-1): + X = pyro.sample("X", dist.Normal(x_loc, x_scale).to_event(1)) + A = pyro.sample( + "A", + dist.Bernoulli( + logits=torch.einsum("...np,...p->...n", X, propensity_weights) + ), + ) + return pyro.sample( + "Y", + self.link_fn( + torch.einsum("...np,...p->...n", X, outcome_weights) + + A * tau + + intercept + ), + ) + + +class KnownCovariateDistModel(HighDimLinearModel): + def sample_covariate_loc_scale(self): + return torch.zeros(self.p), torch.ones(self.p) + + +class BenchmarkLinearModel(HighDimLinearModel): + def __init__( + self, + p: int, + link_fn: Callable[..., dist.Distribution], + alpha: int, + beta: int, + treatment_weight: float = 0.0, + ): + super().__init__(p, link_fn) + self.alpha = alpha # sparsity of propensity weights + self.beta = beta # sparisty of outcome weights + self.treatment_weight = treatment_weight + + def sample_outcome_weights(self): + outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p) + outcome_weights[self.beta :] = 0.0 + return outcome_weights + + def sample_treatment_null_weight(self): + return torch.tensor(0.0) + + def sample_propensity_weights(self): + propensity_weights = 1 / math.sqrt(self.alpha) * torch.ones(self.p) + propensity_weights[self.alpha :] = 0.0 + return propensity_weights + + def sample_treatment_weight(self): + return 
torch.tensor(self.treatment_weight) + + def sample_intercept(self): + return torch.tensor(0.0) + + def sample_covariate_loc_scale(self): + return torch.zeros(self.p), torch.ones(self.p) + + +class MLEGuide(torch.nn.Module): + def __init__(self, mle_est: ParamDict): + super().__init__() + self.names = list(mle_est.keys()) + for name, value in mle_est.items(): + setattr(self, name + "_param", torch.nn.Parameter(value)) + + def forward(self, *args, **kwargs): + for name in self.names: + value = getattr(self, name + "_param") + pyro.sample(name, dist.Delta(value)) + + +class ATETestPoint(TypedDict): + X: torch.Tensor + A: torch.Tensor + Y: torch.Tensor + + +class ATEParamDict(TypedDict): + propensity_weights: torch.Tensor + outcome_weights: torch.Tensor + treatment_weight: torch.Tensor + intercept: torch.Tensor + + +def closed_form_ate_correction( + X_test: ATETestPoint, theta: ATEParamDict +) -> Tuple[torch.Tensor, torch.Tensor]: + X = X_test["X"] + A = X_test["A"] + Y = X_test["Y"] + pi_X = torch.sigmoid(X.mv(theta["propensity_weights"])) + mu_X = ( + X.mv(theta["outcome_weights"]) + + A * theta["treatment_weight"] + + theta["intercept"] + ) + analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X) + analytic_correction = analytic_eif_at_test_pts.mean() + return analytic_correction, analytic_eif_at_test_pts diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py new file mode 100644 index 000000000..e43015282 --- /dev/null +++ b/tests/robust/test_handlers.py @@ -0,0 +1,85 @@ +import functools +from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar + +import pyro +import pytest +import torch +from typing_extensions import ParamSpec + +from chirho.robust.handlers.estimators import one_step_corrected_estimator +from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +ModelTestCase = Tuple[ + Callable[[], Callable], Callable[[Callable], Callable], Set[str], Optional[int] +] + +MODEL_TEST_CASES: List[ModelTestCase] = [ + (SimpleModel, lambda _: SimpleGuide(), {"y"}, 1), + (SimpleModel, lambda _: SimpleGuide(), {"y"}, None), + pytest.param( + SimpleModel, + pyro.infer.autoguide.AutoNormal, + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), +] + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +@pytest.mark.parametrize("estimation_method", [one_step_corrected_estimator]) +def test_estimator_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, + estimation_method, +): + model = model() + guide = guide(model) + model(), guide() # initialize + + with torch.no_grad(): + test_datum = { + k: v[0] + for k, v in pyro.infer.Predictive( + model, num_samples=2, return_sites=obs_names, parallel=True + )().items() + } + + estimator = estimation_method( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + 
cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + estimate_on_test: Mapping[str, torch.Tensor] = estimator() + assert len(estimate_on_test) > 0 + for k, v in estimate_on_test.items(): + assert not torch.isnan(v).any(), f"{estimation_method} for {k} had nans" + assert not torch.isinf(v).any(), f"{estimation_method} for {k} had infs" + assert not torch.isclose( + v, torch.zeros_like(v) + ).all(), f"{estimation_method} estimator for {k} was zero" diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py new file mode 100644 index 000000000..6f9dfde8d --- /dev/null +++ b/tests/robust/test_internals_compositions.py @@ -0,0 +1,226 @@ +import functools +import warnings + +import pyro +import pytest +import torch + +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.indexed.ops import indices_of +from chirho.robust.handlers.predictive import PredictiveModel +from chirho.robust.internals.linearize import ( + conjugate_gradient_solve, + make_empirical_fisher_vp, +) +from chirho.robust.internals.nmc import ( + BatchedLatents, + BatchedNMCLogMarginalLikelihood, + BatchedObservations, +) +from chirho.robust.internals.utils import make_functional_call, reset_rng_state + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + + +def test_empirical_fisher_vp_nmclikelihood_cg_composition(): + model = SimpleModel() + guide = SimpleGuide() + model(), guide() # initialize + log_prob = BatchedNMCLogMarginalLikelihood( + PredictiveModel(model, guide), num_samples=100 + ) + log_prob_params, func_log_prob = make_functional_call(log_prob) + func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob) + + predictive = pyro.infer.Predictive( + model, guide=guide, num_samples=1000, parallel=True, return_sites=["y"] + ) + predictive_params, func_predictive = make_functional_call(predictive) + + cg_solver = functools.partial(conjugate_gradient_solve, cg_iters=2) + + with torch.no_grad(): + data = func_predictive(predictive_params) + + fvp = torch.func.vmap( + make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + ) + + v = { + k: torch.ones_like(v).unsqueeze(0) + if k != "model.guide.loc_a" + else torch.zeros_like(v).unsqueeze(0) + for k, v in log_prob_params.items() + } + + # For this model, fvp for loc_a is zero. 
See + # https://github.com/BasisResearch/chirho/issues/427 + assert fvp(v)["model.guide.loc_a"].abs().max() == 0 + assert all(fvp_vk.shape == v[k].shape for k, fvp_vk in fvp(v).items()) + + solve_one = cg_solver(fvp, v) + solve_two = cg_solver(fvp, v) + + if solve_one["model.guide.loc_a"].abs().max() > 1e6: + warnings.warn( + "solve_one['guide.loc_a'] is large (max entry={}).".format( + solve_one["model.guide.loc_a"].abs().max() + ) + ) + + if solve_one["model.guide.loc_b"].abs().max() > 1e6: + warnings.warn( + "solve_one['guide.loc_b'] is large (max entry={}).".format( + solve_one["model.guide.loc_b"].abs().max() + ) + ) + + assert torch.allclose( + solve_one["model.guide.loc_a"], + torch.zeros_like(log_prob_params["model.guide.loc_a"]), + ) + assert torch.allclose( + solve_one["model.guide.loc_a"], solve_two["model.guide.loc_a"] + ) + assert torch.allclose( + solve_one["model.guide.loc_b"], solve_two["model.guide.loc_b"] + ) + + +link_functions = [ + lambda mu: pyro.distributions.Normal(mu, 1.0), + lambda mu: pyro.distributions.Bernoulli(logits=mu), + lambda mu: pyro.distributions.Beta(concentration1=mu, concentration0=1.0), + lambda mu: pyro.distributions.Exponential(rate=mu), +] + + +@pytest.mark.parametrize("link_fn", link_functions) +def test_nmc_likelihood_seeded(link_fn): + model = SimpleModel(link_fn=link_fn) + guide = SimpleGuide() + model(), guide() # initialize + + log_prob = BatchedNMCLogMarginalLikelihood( + PredictiveModel(model, guide), num_samples=3, max_plate_nesting=3 + ) + log_prob_params, func_log_prob = make_functional_call(log_prob) + + func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob) + + datapoint = {"y": torch.tensor([1.0, 2.0, 3.0])} + prob_call_one = func_log_prob(log_prob_params, datapoint) + prob_call_two = func_log_prob(log_prob_params, datapoint) + prob_call_three = func_log_prob(log_prob_params, datapoint) + assert torch.allclose(prob_call_two, prob_call_three) + assert torch.allclose(prob_call_one, prob_call_two) + + data = {"y": torch.tensor([[0.3665, 1.5440, 2.2210], [0.3665, 1.5440, 2.2210]])} + + fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data) + + v = {k: torch.ones_like(v) for k, v in log_prob_params.items()} + + assert ( + fvp(v)["model.guide.loc_a"].abs().max() + + fvp(v)["model.guide.loc_b"].abs().max() + ) > 0 + + # Check if fvp agrees across multiple calls of same `fvp` object + assert torch.allclose(fvp(v)["model.guide.loc_a"], fvp(v)["model.guide.loc_a"]) + assert torch.allclose(fvp(v)["model.guide.loc_b"], fvp(v)["model.guide.loc_b"]) + + +@pytest.mark.parametrize("pad_dim", [0, 1, 2]) +def test_batched_observations(pad_dim: int): + max_plate_nesting = 1 + pad_dim + obs_plate_name = "__dummy_plate__" + num_particles_obs = 3 + model = SimpleModel() + guide = SimpleGuide() + + model(), guide() # initialize + + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) + + test_data = predictive() + + with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): + with pyro.poutine.trace() as tr: + with BatchedObservations(test_data, name=obs_plate_name): + model() + + tr.trace.compute_log_prob() + + for name, node in tr.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample( + node + ): + if name in test_data: + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + else: + assert obs_plate_name not in 
indices_of( + node["log_prob"], event_dim=0 + ) + assert obs_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + + +@pytest.mark.parametrize("pad_dim", [0, 1, 2]) +def test_batched_latents_observations(pad_dim: int): + max_plate_nesting = 1 + pad_dim + num_particles_latent = 5 + num_particles_obs = 3 + obs_plate_name = "__dummy_plate__" + latent_plate_name = "__dummy_latents__" + model = SimpleModel() + guide = SimpleGuide() + + model(), guide() # initialize + + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) + + test_data = predictive() + + with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): + with pyro.poutine.trace() as tr: + with BatchedLatents( + num_particles=num_particles_latent, name=latent_plate_name + ): + with BatchedObservations(test_data, name=obs_plate_name): + model() + + tr.trace.compute_log_prob() + + for name, node in tr.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample( + node + ): + if name in test_data: + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + else: + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py new file mode 100644 index 000000000..435632789 --- /dev/null +++ b/tests/robust/test_internals_linearize.py @@ -0,0 +1,359 @@ +import functools +from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar + +import pyro +import pyro.distributions as dist +import pytest +import torch +from pyro.infer.predictive import Predictive +from typing_extensions import ParamSpec + +from chirho.robust.handlers.predictive import PredictiveModel +from chirho.robust.internals.linearize import ( + conjugate_gradient_solve, + linearize, + make_empirical_fisher_vp, +) + +from .robust_fixtures import ( + BenchmarkLinearModel, + DataConditionedModel, + GaussianModel, + KnownCovariateDistModel, + MLEGuide, + SimpleGuide, + SimpleModel, + closed_form_ate_correction, + gaussian_log_prob, +) + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +@pytest.mark.parametrize("ndim", [1, 2, 3, 10]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize("num_particles", [1, 4]) +def test_batch_cg_solve(ndim: int, dtype: torch.dtype, num_particles: int): + cg_iters = None + residual_tol = 1e-10 + + U = torch.rand(ndim, ndim, dtype=dtype) + A = torch.eye(ndim, dtype=dtype) + 0.1 * U.mm(U.t()) + expected_x = torch.randn(num_particles, ndim, dtype=dtype) + b = torch.einsum("ij,nj->ni", A, expected_x) + assert b.shape == (num_particles, ndim) + + batch_solve = functools.partial( + conjugate_gradient_solve, + lambda v: torch.einsum("ij,nj->ni", A, v), + cg_iters=cg_iters, + residual_tol=residual_tol, + ) + + actual_x = batch_solve(b) + + assert torch.all(torch.sum((actual_x - expected_x) ** 2, dim=1) < 1e-4) + + +ModelTestCase = Tuple[ + Callable[[], Callable], Callable[[Callable], Callable], Set[str], 
Optional[int] +] + +MODEL_TEST_CASES: List[ModelTestCase] = [ + (SimpleModel, lambda _: SimpleGuide(), {"y"}, 1), + (SimpleModel, lambda _: SimpleGuide(), {"y"}, None), + pytest.param( + SimpleModel, + pyro.infer.autoguide.AutoNormal, + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), +] + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +def test_nmc_param_influence_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, +): + model = model() + guide = guide(model) + + model(), guide() # initialize + + param_eif = linearize( + PredictiveModel(model, guide), + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + ) + + with torch.no_grad(): + test_datum = { + k: v[0] + for k, v in pyro.infer.Predictive( + model, num_samples=2, return_sites=obs_names, parallel=True + )().items() + } + + test_datum_eif: Mapping[str, torch.Tensor] = param_eif(test_datum) + assert len(test_datum_eif) > 0 + for k, v in test_datum_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + if not k.endswith("guide.loc_a"): + assert not torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} was zero" + else: + assert torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} should be zero" + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +def test_nmc_param_influence_vmap_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, +): + model = model() + guide = guide(model) + + model(), guide() # initialize + + param_eif = linearize( + PredictiveModel(model, guide), + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + ) + + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=obs_names, parallel=True + )() + + test_data_eif: Mapping[str, torch.Tensor] = param_eif(test_data) + assert len(test_data_eif) > 0 + for k, v in test_data_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + if not k.endswith("guide.loc_a"): + assert not torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} was zero" + else: + assert torch.isclose( + v, torch.zeros_like(v) + ).all(), f"eif for {k} should be zero" + + +@pytest.mark.parametrize( + "loc", [torch.zeros(2, requires_grad=True), torch.ones(2, requires_grad=True)] +) +@pytest.mark.parametrize( + "cov_mat", + [ + torch.eye(2, requires_grad=False), + torch.tensor(torch.ones(2, 2) + torch.eye(2), requires_grad=False), + ], +) +@pytest.mark.parametrize( + "v", + [ + torch.tensor([1.0, 0.0], requires_grad=False), + torch.tensor([0.0, 1.0], requires_grad=False), + torch.tensor([1.0, 1.0], requires_grad=False), + torch.tensor([0.0, 0.0], requires_grad=False), + ], +) +def test_empirical_fisher_vp_against_analytical( + loc: torch.Tensor, cov_mat: torch.Tensor, v: torch.Tensor 
+): + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + + empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"] + + prec_matrix = torch.linalg.inv(cov_mat) + true_vp = prec_matrix.mv(v) + + assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1)) + + +@pytest.mark.parametrize( + "data_config", + [ + (torch.zeros(1, requires_grad=True), torch.eye(1)), + (torch.ones(2, requires_grad=True), torch.eye(2)), + ], +) +def test_fisher_vmap(data_config): + loc, cov_mat = data_config + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + v_single_one = torch.ones(cov_mat.shape[1]) + v_single_two = 0.4 * torch.ones(cov_mat.shape[1]) + v_batch = torch.stack([v_single_one, v_single_two], axis=0) + empirical_fisher_vp_func_batched = torch.func.vmap(empirical_fisher_vp_func) + + # Check if fisher vector product works on a single vector and a batch of vectors + single_one_out = empirical_fisher_vp_func({"loc": v_single_one}) + single_two_out = empirical_fisher_vp_func({"loc": v_single_two}) + batch_out = empirical_fisher_vp_func_batched({"loc": v_batch}) + + assert torch.allclose(batch_out["loc"][0], single_one_out["loc"]) + assert torch.allclose(batch_out["loc"][1], single_two_out["loc"]) + + with pytest.raises(RuntimeError): + # Fisher vector product should not work on a batch of vectors + empirical_fisher_vp_func({"loc": v_batch}) + with pytest.raises(RuntimeError): + # Batched Fisher vector product should not work on a single vector + empirical_fisher_vp_func_batched({"loc": v_single_one}) + + +@pytest.mark.parametrize( + "data_config", + [ + (torch.zeros(1, requires_grad=True), torch.eye(1)), + (torch.ones(2, requires_grad=True), torch.eye(2)), + ], +) +def test_fisher_grad_smoke(data_config): + loc, cov_mat = data_config + func_log_prob = gaussian_log_prob + log_prob_params = {"loc": loc} + N_monte_carlo = 10000 + data = Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc) + empirical_fisher_vp_func = make_empirical_fisher_vp( + func_log_prob, log_prob_params, data, cov_mat=cov_mat + ) + + v = 0.5 * torch.ones(cov_mat.shape[1], requires_grad=True) + + def f(x): + return empirical_fisher_vp_func({"loc": x})["loc"].sum() + + # Check using `torch.func.grad` + assert ( + torch.func.grad(f)(v).sum() != 0 + ), "Zero gradients but expected non-zero gradients" + + # Check using autograd + assert torch.autograd.gradcheck( + f, v, atol=0.2 + ), "Finite difference gradients do not match autograd gradients" + + +def test_linearize_against_analytic_ate(): + p = 1 + alpha = 1 + beta = 1 + N_train = 100 + N_test = 100 + + def link(mu): + return dist.Normal(mu, 1.0) + + # Generate data + benchmark_model = BenchmarkLinearModel(p, link, alpha, beta) + D_train = Predictive( + benchmark_model, num_samples=N_train, return_sites=["X", "A", "Y"] + )() + D_train = {k: v.squeeze(-1) for k, v in D_train.items()} + D_test = Predictive( + benchmark_model, num_samples=N_test, return_sites=["X", "A", "Y"] + )() + D_test_flat = {k: v.squeeze(-1) for k, v in D_test.items()} + + model = KnownCovariateDistModel(p, link) + 
conditioned_model = DataConditionedModel(model) + guide_train = pyro.infer.autoguide.AutoDelta(conditioned_model) + elbo = pyro.infer.Trace_ELBO()(conditioned_model, guide_train) + + # initialize parameters + elbo(D_train) + + adam = torch.optim.Adam(elbo.parameters(), lr=0.03) + + # Do gradient steps + for _ in range(500): + adam.zero_grad() + loss = elbo(D_train) + loss.backward() + adam.step() + + theta_hat = { + k: v.clone().detach().requires_grad_(True) for k, v in guide_train().items() + } + _, analytic_eif_at_test_pts = closed_form_ate_correction(D_test_flat, theta_hat) + + mle_guide = MLEGuide(theta_hat) + param_eif = linearize( + PredictiveModel(model, mle_guide), + num_samples_outer=10000, + num_samples_inner=1, + cg_iters=4, # dimension of params = 4 + pointwise_influence=True, + ) + + test_data_eif = param_eif(D_test) + median_abs_error = torch.abs( + test_data_eif["model.guide.treatment_weight_param"] - analytic_eif_at_test_pts + ).median() + median_scale = torch.abs(analytic_eif_at_test_pts).median() + if median_scale > 1: + assert median_abs_error / median_scale < 0.5 + else: + assert median_abs_error < 0.5 + + # Test w/ pointwise_influence=False + param_eif = linearize( + PredictiveModel(model, mle_guide), + num_samples_outer=10000, + num_samples_inner=1, + cg_iters=4, # dimension of params = 4 + pointwise_influence=False, + ) + + test_data_eif = param_eif(D_test) + assert torch.allclose( + test_data_eif["model.guide.treatment_weight_param"][0], + analytic_eif_at_test_pts.mean(), + atol=0.5, + ) diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py new file mode 100644 index 000000000..e3d5e5290 --- /dev/null +++ b/tests/robust/test_ops.py @@ -0,0 +1,122 @@ +import functools +from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar + +import pyro +import pytest +import torch +from typing_extensions import ParamSpec + +from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel +from chirho.robust.ops import influence_fn + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +ModelTestCase = Tuple[ + Callable[[], Callable], Callable[[Callable], Callable], Set[str], Optional[int] +] + +MODEL_TEST_CASES: List[ModelTestCase] = [ + (SimpleModel, lambda _: SimpleGuide(), {"y"}, 1), + (SimpleModel, lambda _: SimpleGuide(), {"y"}, None), + pytest.param( + SimpleModel, + pyro.infer.autoguide.AutoNormal, + {"y"}, + 1, + marks=pytest.mark.xfail( + reason="torch.func autograd doesnt work with PyroParam" + ), + ), +] + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +def test_nmc_predictive_influence_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, +): + model = model() + guide = guide(model) + model(), guide() # initialize + + with torch.no_grad(): + test_datum = { + k: v[0] + for k, v in pyro.infer.Predictive( + model, num_samples=2, return_sites=obs_names, parallel=True + )().items() + } + + predictive_eif = influence_fn( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + 
num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif() + assert len(test_datum_eif) > 0 + for k, v in test_datum_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" + + +@pytest.mark.parametrize("model,guide,obs_names,max_plate_nesting", MODEL_TEST_CASES) +@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) +@pytest.mark.parametrize("cg_iters", [None, 1, 10]) +@pytest.mark.parametrize("num_predictive_samples", [1, 5]) +def test_nmc_predictive_influence_vmap_smoke( + model, + guide, + obs_names, + max_plate_nesting, + num_samples_outer, + num_samples_inner, + cg_iters, + num_predictive_samples, +): + model = model() + guide = guide(model) + + model(), guide() # initialize + + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=obs_names, parallel=True + )() + + predictive_eif = influence_fn( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_data, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + test_data_eif: Mapping[str, torch.Tensor] = predictive_eif() + assert len(test_data_eif) > 0 + for k, v in test_data_eif.items(): + assert not torch.isnan(v).any(), f"eif for {k} had nans" + assert not torch.isinf(v).any(), f"eif for {k} had infs" + assert not torch.isclose(v, torch.zeros_like(v)).all(), f"eif for {k} was zero" diff --git a/tests/robust/test_performance.py b/tests/robust/test_performance.py new file mode 100644 index 000000000..34d5e4d02 --- /dev/null +++ b/tests/robust/test_performance.py @@ -0,0 +1,180 @@ +import math +import time +import warnings +from functools import partial +from typing import Any, Callable, Container, Generic, Optional, TypeVar + +import pyro +import pytest +import torch +from typing_extensions import ParamSpec + +from chirho.indexed.handlers import DependentMaskMessenger +from chirho.observational.handlers import condition +from chirho.robust.handlers.predictive import PredictiveModel +from chirho.robust.internals.linearize import make_empirical_fisher_vp +from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood +from chirho.robust.internals.utils import guess_max_plate_nesting, make_functional_call +from chirho.robust.ops import Point + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +class _UnmaskNamedSites(DependentMaskMessenger): + names: Container[str] + + def __init__(self, names: Container[str]): + self.names = names + + def get_mask( + self, + dist: pyro.distributions.Distribution, + value: Optional[torch.Tensor], + device: torch.device = torch.device("cpu"), + name: Optional[str] = None, + ) -> torch.Tensor: + return torch.tensor(name is None or name in self.names, device=device) + + +class OldNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + model: Callable[P, Any] + guide: Callable[P, Any] + num_samples: int + max_plate_nesting: Optional[int] + + def __init__( + self, + model: torch.nn.Module, + guide: torch.nn.Module, + *, + num_samples: int = 1, + 
max_plate_nesting: Optional[int] = None, + ): + super().__init__() + self.model = model + self.guide = guide + self.num_samples = num_samples + self.max_plate_nesting = max_plate_nesting + + def forward( + self, data: Point[T], *args: P.args, **kwargs: P.kwargs + ) -> torch.Tensor: + if self.max_plate_nesting is None: + self.max_plate_nesting = guess_max_plate_nesting( + self.model, self.guide, *args, **kwargs + ) + warnings.warn( + "Since max_plate_nesting is not specified, \ + the first call to NMCLogPredictiveLikelihood will not be seeded properly. \ + See https://github.com/BasisResearch/chirho/pull/408" + ) + + masked_guide = pyro.poutine.mask(mask=False)(self.guide) + masked_model = _UnmaskNamedSites(names=set(data.keys()))( + condition(data=data)(self.model) + ) + log_weights = pyro.infer.importance.vectorized_importance_weights( + masked_model, + masked_guide, + *args, + num_samples=self.num_samples, + max_plate_nesting=self.max_plate_nesting, + **kwargs, + )[0] + return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + + +class SimpleMultivariateGaussianModel(pyro.nn.PyroModule): + def __init__(self, p): + super().__init__() + self.p = p + + def forward(self): + loc = pyro.sample( + "loc", pyro.distributions.Normal(torch.zeros(self.p), 1.0).to_event(1) + ) + cov_mat = torch.eye(self.p) + return pyro.sample("y", pyro.distributions.MultivariateNormal(loc, cov_mat)) + + +class SimpleMultivariateGuide(torch.nn.Module): + def __init__(self, p): + super().__init__() + self.loc_ = torch.nn.Parameter(torch.rand((p,))) + self.p = p + + def forward(self): + return pyro.sample("loc", pyro.distributions.Normal(self.loc_, 1).to_event(1)) + + +model_guide_types = [ + ( + partial(SimpleMultivariateGaussianModel, p=500), + partial(SimpleMultivariateGuide, p=500), + ), + (SimpleModel, SimpleGuide), +] + + +@pytest.mark.skip(reason="This test is too slow to run on CI") +@pytest.mark.parametrize("model_guide", model_guide_types) +def test_empirical_fisher_vp_performance_with_likelihood(model_guide): + num_monte_carlo = 10000 + model_family, guide_family = model_guide + + model = model_family() + guide = guide_family() + + model() + guide() + + start_time = time.time() + data = pyro.infer.Predictive( + model, guide=guide, num_samples=num_monte_carlo, return_sites=["y"] + )() + end_time = time.time() + print("Data generation time (s): ", end_time - start_time) + + log1_prob_params, func1_log_prob = make_functional_call( + OldNMCLogPredictiveLikelihood(model, guide, max_plate_nesting=1) + ) + batched_func1_log_prob = torch.func.vmap( + func1_log_prob, in_dims=(None, 0), randomness="different" + ) + + log2_prob_params, func2_log_prob = make_functional_call( + BatchedNMCLogMarginalLikelihood(PredictiveModel(model, guide)) + ) + + fisher_hessian_vmapped = make_empirical_fisher_vp( + batched_func1_log_prob, log1_prob_params, data + ) + + fisher_hessian_batched = make_empirical_fisher_vp( + func2_log_prob, log2_prob_params, data + ) + + v1 = { + k: torch.ones_like(v) if k != "guide.loc_a" else torch.zeros_like(v) + for k, v in log1_prob_params.items() + } + v2 = {f"model.{k}": v for k, v in v1.items()} + + func2_log_prob(log2_prob_params, data) + + start_time = time.time() + fisher_hessian_vmapped(v1) + end_time = time.time() + print("Hessian vmapped time (s): ", end_time - start_time) + + start_time = time.time() + fisher_hessian_batched(v2) + end_time = time.time() + print("Hessian manual batched time (s): ", end_time - start_time)
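
For readers skimming this patch, here is a minimal, self-contained sketch of the batched conjugate-gradient solver exercised by test_batch_cg_solve above. It is an illustrative usage example written against the call signature visible in that test (linear operator first, then the right-hand side, with cg_iters and residual_tol keywords); it is not part of the change itself.

import functools

import torch

from chirho.robust.internals.linearize import conjugate_gradient_solve

ndim, num_particles = 3, 4
U = torch.rand(ndim, ndim)
A = torch.eye(ndim) + 0.1 * U.mm(U.t())  # symmetric positive-definite operator

expected_x = torch.randn(num_particles, ndim)
b = torch.einsum("ij,nj->ni", A, expected_x)  # a batch of right-hand sides

# Solve A x = b for each row of b; the operator is passed as a closure.
batch_solve = functools.partial(
    conjugate_gradient_solve,
    lambda v: torch.einsum("ij,nj->ni", A, v),
    cg_iters=None,
    residual_tol=1e-10,
)
actual_x = batch_solve(b)
assert torch.allclose(actual_x, expected_x, atol=1e-2)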
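
Similarly, a self-contained sketch of the check performed by test_empirical_fisher_vp_against_analytical: for a Gaussian with known covariance, the empirical Fisher vector product should approximate multiplication by the precision matrix. The gaussian_log_prob and GaussianModel fixtures live in robust_fixtures and are not shown in this diff, so the log-density helper and data generation below are hypothetical stand-ins; only make_empirical_fisher_vp is the library function under test.

import torch

from chirho.robust.internals.linearize import make_empirical_fisher_vp

# Hypothetical stand-in for the `gaussian_log_prob` fixture: per-datapoint
# log density of N(loc, cov_mat) evaluated at data["y"].
def gaussian_log_prob(params, data, cov_mat):
    return torch.distributions.MultivariateNormal(
        params["loc"], covariance_matrix=cov_mat
    ).log_prob(data["y"])

loc = torch.zeros(2, requires_grad=True)
cov_mat = torch.eye(2)
with torch.no_grad():
    data = {
        "y": torch.distributions.MultivariateNormal(loc, cov_mat).sample((10000,))
    }

# Extra keyword arguments are forwarded to the log-prob function, as in the test.
fvp = make_empirical_fisher_vp(gaussian_log_prob, {"loc": loc}, data, cov_mat=cov_mat)

v = torch.tensor([1.0, 0.0])
# The Fisher information of this model is the precision matrix, so the
# product should be close to inv(cov_mat) @ v up to Monte Carlo error.
assert torch.allclose(
    fvp({"loc": v})["loc"], torch.linalg.inv(cov_mat) @ v, atol=0.1
)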
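
Finally, a compact end-to-end sketch of the two public entry points these test modules cover: linearize, which returns efficient-influence-function estimates for model parameters (test_internals_linearize.py), and influence_fn applied to a PredictiveFunctional (test_ops.py). The SimpleModel and SimpleGuide fixtures are defined in robust_fixtures, which is not part of this excerpt, so ToyModel and ToyGuide below are hypothetical stand-ins; the guide uses an ordinary torch.nn.Parameter because, per the xfail markers above, torch.func autograd does not work with PyroParam.

import functools

import pyro
import pyro.distributions as dist
import torch

from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel
from chirho.robust.internals.linearize import linearize
from chirho.robust.ops import influence_fn

pyro.settings.set(module_local_params=True)


class ToyModel(pyro.nn.PyroModule):
    # Hypothetical stand-in for the SimpleModel fixture.
    def forward(self):
        a = pyro.sample("a", dist.Normal(0.0, 1.0))
        return pyro.sample("y", dist.Normal(a, 1.0))


class ToyGuide(torch.nn.Module):
    # Hypothetical stand-in for the SimpleGuide fixture.
    def __init__(self):
        super().__init__()
        self.loc_a = torch.nn.Parameter(torch.zeros(()))

    def forward(self):
        return pyro.sample("a", dist.Normal(self.loc_a, 1.0))


model, guide = ToyModel(), ToyGuide()
model(), guide()  # initialize

with torch.no_grad():
    test_datum = {
        k: v[0]
        for k, v in pyro.infer.Predictive(
            model, num_samples=2, return_sites=["y"], parallel=True
        )().items()
    }

# Efficient influence function of the model parameters at a test datum.
param_eif = linearize(
    PredictiveModel(model, guide),
    max_plate_nesting=1,
    num_samples_outer=10,
    num_samples_inner=10,
    cg_iters=None,
)
print(param_eif(test_datum))

# Influence function of a posterior-predictive functional at the same datum.
predictive_eif = influence_fn(
    functools.partial(PredictiveFunctional, num_samples=5),
    test_datum,
    max_plate_nesting=1,
    num_samples_outer=10,
    num_samples_inner=10,
    cg_iters=None,
)(PredictiveModel(model, guide))
print(predictive_eif())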