From 6eea60f91366a822c31d1986d74218c10369282b Mon Sep 17 00:00:00 2001
From: Eli
Date: Tue, 9 Jan 2024 10:51:30 -0500
Subject: [PATCH] Make influence_fn return a Functional

---
 chirho/robust/handlers/estimators.py        | 37 ++++-----
 chirho/robust/internals/linearize.py        | 22 ++----
 chirho/robust/internals/predictive.py       | 17 ++--
 chirho/robust/ops.py                        | 90 ++++++++++-----------
 tests/robust/test_handlers.py               | 25 +++---
 tests/robust/test_internals_compositions.py |  7 +-
 tests/robust/test_internals_linearize.py    | 49 ++++++------
 tests/robust/test_ops.py                    | 40 +++++-----
 tests/robust/test_performance.py            |  7 +-
 9 files changed, 138 insertions(+), 156 deletions(-)

diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py
index 9d2d70f2d..9b3dc6864 100644
--- a/chirho/robust/handlers/estimators.py
+++ b/chirho/robust/handlers/estimators.py
@@ -1,27 +1,27 @@
-from typing import Any, Callable
+from typing import TypeVar
 
-from typing_extensions import Concatenate
+from typing_extensions import ParamSpec
 
-from chirho.robust.ops import Functional, P, Point, S, T, influence_fn
+from chirho.robust.ops import Functional, Point, influence_fn
+
+P = ParamSpec("P")
+S = TypeVar("S")
+T = TypeVar("T")
 
 
 def one_step_correction(
-    model: Callable[P, Any],
-    guide: Callable[P, Any],
     functional: Functional[P, S],
+    test_data: Point[T],
     **influence_kwargs,
-) -> Callable[Concatenate[Point[T], P], S]:
+) -> Functional[P, S]:
     """
-    Returns a function that computes the one-step correction for the
-    functional at a specified set of test points as discussed in [1].
+    Returns a functional that computes the one-step correction for
+    ``functional`` at the points in ``test_data`` as discussed in [1].
 
-    :param model: Python callable containing Pyro primitives.
-    :type model: Callable[P, Any]
-    :param guide: Python callable containing Pyro primitives.
-    :type guide: Callable[P, Any]
-    :param functional: model summary of interest, which is a function of the
-        model and guide.
+    :param functional: model summary of interest
    :type functional: Functional[P, S]
+    :param test_data: test data at which to compute the one-step correction
+    :type test_data: Point[T]
     :return: function to compute the one-step correction
-    :rtype: Callable[Concatenate[Point[T], P], S]
+    :rtype: Functional[P, S]
     """
     influence_kwargs_one_step = influence_kwargs.copy()
     influence_kwargs_one_step["pointwise_influence"] = False
-    eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step)
-
-    def _one_step(test_data: Point[T], *args, **kwargs) -> S:
-        return eif_fn(test_data, *args, **kwargs)
-
-    return _one_step
+    return influence_fn(functional, test_data, **influence_kwargs_one_step)
diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py
index e4fbdd115..6cbef3caa 100644
--- a/chirho/robust/internals/linearize.py
+++ b/chirho/robust/internals/linearize.py
@@ -220,7 +220,7 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:
 
 def linearize(
     model: Callable[P, Any],
-    guide: Callable[P, Any],
+    points: Point[T],
     *,
     num_samples_outer: int,
     num_samples_inner: Optional[int] = None,
@@ -228,10 +228,10 @@ def linearize(
     cg_iters: Optional[int] = None,
     residual_tol: float = 1e-4,
     pointwise_influence: bool = True,
-) -> Callable[Concatenate[Point[T], P], ParamDict]:
+) -> Callable[P, ParamDict]:
     r"""
     Returns the influence function associated with the parameters
-    of ``guide`` and probabilistic program ``model``. This function
+    of probabilistic program ``model``. This function
     computes the following quantity at an arbitrary point :math:`x^{\prime}`:
 
     .. math::
 
     :param model: Python callable containing Pyro primitives.
     :type model: Callable[P, Any]
-    :param guide: Python callable containing Pyro primitives.
-        Must only contain continuous latent variables.
-    :type guide: Callable[P, Any]
     :param num_samples_outer: number of Monte Carlo samples to
         approximate Fisher information in :func:`make_empirical_fisher_vp`
     :type num_samples_outer: int
@@ -271,7 +268,7 @@ def linearize(
         over ``points``. Defaults to True.
     :type pointwise_influence: bool, optional
     :return: the influence function associated with the parameters
-    :rtype: Callable[Concatenate[Point[T], P], ParamDict]
+    :rtype: Callable[P, ParamDict]
 
     **Example usage**:
 
@@ -312,13 +309,13 @@ def forward(self):
             )
             points = predictive()
             influence = linearize(
-                model,
-                guide,
+                PredictiveModel(model, guide),
+                points,
                 num_samples_outer=1000,
                 num_samples_inner=1000,
             )
 
-            influence(points)
+            influence()
 
     .. note::
 
@@ -332,19 +329,17 @@ def forward(self):
         https://github.com/BasisResearch/chirho/issues/393.
     """
     assert isinstance(model, torch.nn.Module)
-    assert isinstance(guide, torch.nn.Module)
 
     if num_samples_inner is None:
         num_samples_inner = num_samples_outer**2
 
     predictive = pyro.infer.Predictive(
         model,
-        guide=guide,
         num_samples=num_samples_outer,
         parallel=True,
     )
 
     batched_log_prob = BatchedNMCLogPredictiveLikelihood(
-        model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
+        model, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
     )
     log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob)
     log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
@@ -357,7 +352,6 @@ def forward(self):
     )
 
     def _fn(
-        points: Point[T],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> ParamDict:
diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py
index 8b9721f44..cd8a7cee3 100644
--- a/chirho/robust/internals/predictive.py
+++ b/chirho/robust/internals/predictive.py
@@ -116,12 +116,12 @@ class PredictiveModel(Generic[P, T], torch.nn.Module):
     """
 
     model: Callable[P, T]
-    guide: Callable[P, Any]
+    guide: Optional[Callable[P, Any]]
 
     def __init__(
         self,
         model: Callable[P, T],
-        guide: Callable[P, Any],
+        guide: Optional[Callable[P, Any]] = None,
     ):
         super().__init__()
         self.model = model
@@ -135,7 +135,8 @@ def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
         :rtype: T
         """
         with pyro.poutine.trace() as guide_tr:
-            self.guide(*args, **kwargs)
+            if self.guide is not None:
+                self.guide(*args, **kwargs)
 
         block_guide_sample_sites = pyro.poutine.block(
             hide=[
@@ -175,13 +176,13 @@ class PredictiveFunctional(Generic[P, T], torch.nn.Module):
     """
 
     model: Callable[P, Any]
-    guide: Callable[P, Any]
+    guide: Optional[Callable[P, Any]]
     num_samples: int
 
     def __init__(
         self,
         model: torch.nn.Module,
-        guide: torch.nn.Module,
+        guide: Optional[torch.nn.Module] = None,
         *,
         num_samples: int = 1,
         max_plate_nesting: Optional[int] = None,
@@ -245,13 +246,13 @@ class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module):
     :type num_samples: int, optional
     """
 
     model: Callable[P, Any]
-    guide: Callable[P, Any]
+    guide: Optional[Callable[P, Any]]
     num_samples: int
 
     def __init__(
         self,
         model: torch.nn.Module,
-        guide: torch.nn.Module,
+        guide: Optional[torch.nn.Module] = None,
         *,
         num_samples: int = 1,
         max_plate_nesting: Optional[int] = None,
@@ -279,7 +280,7 @@ def forward(
         :return: Log predictive likelihood at each datapoint.
         :rtype: torch.Tensor
         """
-        get_nmc_traces = get_importance_traces(PredictiveModel(self.model, self.guide))
+        get_nmc_traces = get_importance_traces(self.model, self.guide)
 
         with IndexPlatesMessenger(first_available_dim=self._first_available_dim):
             with BatchedLatents(self.num_samples, name=self._mc_plate_name):
diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py
index 34b0dd0f5..0c3efc4b7 100644
--- a/chirho/robust/ops.py
+++ b/chirho/robust/ops.py
@@ -1,8 +1,7 @@
-import functools
 from typing import Any, Callable, Mapping, TypeVar
 
 import torch
-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import ParamSpec
 
 from chirho.observational.ops import Observation
 
@@ -12,30 +11,22 @@
 T = TypeVar("T")
 
 Point = Mapping[str, Observation[T]]
-Functional = Callable[[Callable[P, Any], Callable[P, Any]], Callable[P, S]]
+Functional = Callable[[Callable[P, Any]], Callable[P, S]]
 
 
 def influence_fn(
-    model: Callable[P, Any],
-    guide: Callable[P, Any],
-    functional: Functional[P, S],
-    **linearize_kwargs
-) -> Callable[Concatenate[Point[T], P], S]:
+    functional: Functional[P, S], points: Point[T], **linearize_kwargs
+) -> Functional[P, S]:
     """
     Returns the efficient influence function for ``functional``
-    with respect to the parameters of ``guide`` and probabilistic
-    program ``model``.
+    with respect to the parameters of probabilistic program ``model``.
 
-    :param model: Python callable containing Pyro primitives.
-    :type model: Callable[P, Any]
-    :param guide: Python callable containing Pyro primitives.
-    :type guide: Callable[P, Any]
-    :param functional: model summary of interest, which is a function of the
-        model and guide.
+    :param functional: model summary of interest, which is a function of the model.
     :type functional: Functional[P, S]
+    :param points: points at which to compute the efficient influence function
+    :type points: Point[T]
     :return: the efficient influence function for ``functional``
-    :rtype: Callable[Concatenate[Point[T], P], S]
+    :rtype: Functional[P, S]
 
     **Example usage**:
 
     .. code-block:: python
@@ -95,14 +88,13 @@ def forward(self):
             )
             points = predictive()
             influence = influence_fn(
-                model,
-                guide,
                 SimpleFunctional,
+                points,
                 num_samples_outer=1000,
                 num_samples_inner=1000,
-            )
+            )(PredictiveModel(model, guide))
 
-            influence(points)
+            influence()
 
     .. note::
 
@@ -118,31 +110,31 @@
     """
     from chirho.robust.internals.linearize import linearize
    from chirho.robust.internals.utils import make_functional_call
 
-    linearized = linearize(model, guide, **linearize_kwargs)
-    target = functional(model, guide)
-
-    # TODO check that target_params == model_params | guide_params
-    assert isinstance(target, torch.nn.Module)
-    target_params, func_target = make_functional_call(target)
-
-    @functools.wraps(target)
-    def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
-        """
-        Evaluates the efficient influence function for ``functional`` at each
-        point in ``points``.
-
-        :param points: points at which to compute the efficient influence function
-        :type points: Point[T]
-        :return: efficient influence function evaluated at each point in ``points`` or averaged
-        :rtype: S
-        """
-        param_eif = linearized(points, *args, **kwargs)
-        return torch.vmap(
-            lambda d: torch.func.jvp(
-                lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
-            )[1],
-            in_dims=0,
-            randomness="different",
-        )(param_eif)
-
-    return _fn
+    def _functional(model: Callable[P, Any]) -> Callable[P, S]:
+        linearized = linearize(model, points, **linearize_kwargs)
+        target = functional(model)
+
+        # TODO check that target_params == model_params | guide_params
+        assert isinstance(target, torch.nn.Module)
+        target_params, func_target = make_functional_call(target)
+
+        def _fn(*args: P.args, **kwargs: P.kwargs) -> S:
+            """
+            Evaluates the efficient influence function for ``functional`` at each
+            point in ``points``.
+
+            :return: efficient influence function evaluated at each point in ``points`` or averaged
+            :rtype: S
+            """
+            param_eif = linearized(*args, **kwargs)
+            return torch.vmap(
+                lambda d: torch.func.jvp(
+                    lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
+                )[1],
+                in_dims=0,
+                randomness="different",
+            )(param_eif)
+
+        return _fn
+
+    return _functional
diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py
index fc36d0b49..5d9566f29 100644
--- a/tests/robust/test_handlers.py
+++ b/tests/robust/test_handlers.py
@@ -7,7 +7,7 @@
 from typing_extensions import ParamSpec
 
 from chirho.robust.handlers.estimators import one_step_correction
-from chirho.robust.internals.predictive import PredictiveFunctional
+from chirho.robust.internals.predictive import PredictiveFunctional, PredictiveModel
 
 from .robust_fixtures import SimpleGuide, SimpleModel
 
@@ -56,18 +56,6 @@ def test_one_step_correction_smoke(
     guide = guide(model)
     model(), guide()  # initialize
 
-    one_step = one_step_correction(
-        model,
-        guide,
-        functional=functools.partial(
-            PredictiveFunctional, num_samples=num_predictive_samples
-        ),
-        max_plate_nesting=max_plate_nesting,
-        num_samples_outer=num_samples_outer,
-        num_samples_inner=num_samples_inner,
-        cg_iters=cg_iters,
-    )
-
     with torch.no_grad():
         test_datum = {
             k: v[0]
@@ -76,7 +64,16 @@ def test_one_step_correction_smoke(
             )().items()
         }
 
-    one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum)
+    one_step = one_step_correction(
+        functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
+        test_datum,
+        max_plate_nesting=max_plate_nesting,
+        num_samples_outer=num_samples_outer,
+        num_samples_inner=num_samples_inner,
+        cg_iters=cg_iters,
+    )(PredictiveModel(model, guide))
+
+    one_step_on_test: Mapping[str, torch.Tensor] = one_step()
     assert len(one_step_on_test) > 0
     for k, v in one_step_on_test.items():
         assert not torch.isnan(v).any(), f"one_step for {k} had nans"
diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py
index b9924fab5..97ec951fd 100644
--- a/tests/robust/test_internals_compositions.py
+++ b/tests/robust/test_internals_compositions.py
@@ -15,6 +15,7 @@
     BatchedLatents,
     BatchedNMCLogPredictiveLikelihood,
     BatchedObservations,
+    PredictiveModel,
 )
 from chirho.robust.internals.utils import make_functional_call, reset_rng_state
 
@@ -27,7 +28,9 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition():
     model = SimpleModel()
     guide = SimpleGuide()
     model(), guide()  # initialize
-    log_prob = BatchedNMCLogPredictiveLikelihood(model, guide, num_samples=100)
+    log_prob = BatchedNMCLogPredictiveLikelihood(
+        PredictiveModel(model, guide), num_samples=100
+    )
     log_prob_params, func_log_prob = make_functional_call(log_prob)
     func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob)
 
@@ -96,7 +99,7 @@ def test_nmc_likelihood_seeded(link_fn):
     model(), guide()  # initialize
 
     log_prob = BatchedNMCLogPredictiveLikelihood(
-        model, guide, num_samples=3, max_plate_nesting=3
+        PredictiveModel(model, guide), num_samples=3, max_plate_nesting=3
     )
 
     log_prob_params, func_log_prob = make_functional_call(log_prob)
diff --git a/tests/robust/test_internals_linearize.py b/tests/robust/test_internals_linearize.py
index a8a80a536..7b86f3648 100644
--- a/tests/robust/test_internals_linearize.py
+++ b/tests/robust/test_internals_linearize.py
@@ -13,6 +13,7 @@
     linearize,
     make_empirical_fisher_vp,
 )
+from chirho.robust.internals.predictive import PredictiveModel
 
 from .robust_fixtures import (
     BenchmarkLinearModel,
@@ -95,15 +96,6 @@ def test_nmc_param_influence_smoke(
     model(), guide()  # initialize
 
-    param_eif = linearize(
-        model,
-        guide,
-        max_plate_nesting=max_plate_nesting,
-        num_samples_outer=num_samples_outer,
-        num_samples_inner=num_samples_inner,
-        cg_iters=cg_iters,
-    )
-
     with torch.no_grad():
         test_datum = {
             k: v[0]
@@ -112,7 +104,16 @@ def test_nmc_param_influence_smoke(
             )().items()
         }
 
-    test_datum_eif: Mapping[str, torch.Tensor] = param_eif(test_datum)
+    param_eif = linearize(
+        PredictiveModel(model, guide),
+        test_datum,
+        max_plate_nesting=max_plate_nesting,
+        num_samples_outer=num_samples_outer,
+        num_samples_inner=num_samples_inner,
+        cg_iters=cg_iters,
+    )
+
+    test_datum_eif: Mapping[str, torch.Tensor] = param_eif()
     assert len(test_datum_eif) > 0
     for k, v in test_datum_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"
@@ -144,21 +145,21 @@ def test_nmc_param_influence_vmap_smoke(
     model(), guide()  # initialize
 
+    with torch.no_grad():
+        test_data = pyro.infer.Predictive(
+            model, num_samples=4, return_sites=obs_names, parallel=True
+        )()
+
     param_eif = linearize(
-        model,
-        guide,
+        PredictiveModel(model, guide),
+        test_data,
         max_plate_nesting=max_plate_nesting,
         num_samples_outer=num_samples_outer,
         num_samples_inner=num_samples_inner,
         cg_iters=cg_iters,
     )
 
-    with torch.no_grad():
-        test_data = pyro.infer.Predictive(
-            model, num_samples=4, return_sites=obs_names, parallel=True
-        )()
-
-    test_data_eif: Mapping[str, torch.Tensor] = param_eif(test_data)
+    test_data_eif: Mapping[str, torch.Tensor] = param_eif()
     assert len(test_data_eif) > 0
     for k, v in test_data_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"
@@ -326,15 +327,15 @@ def link(mu):
     mle_guide = MLEGuide(theta_hat)
 
     param_eif = linearize(
-        model,
-        mle_guide,
+        PredictiveModel(model, mle_guide),
+        D_test,
         num_samples_outer=10000,
         num_samples_inner=1,
         cg_iters=4,  # dimension of params = 4
         pointwise_influence=True,
     )
 
-    test_data_eif = param_eif(D_test)
+    test_data_eif = param_eif()
     median_abs_error = torch.abs(
         test_data_eif["guide.treatment_weight_param"] - analytic_eif_at_test_pts
     ).median()
@@ -346,15 +347,15 @@ def link(mu):
 
     # Test w/ pointwise_influence=False
     param_eif = linearize(
-        model,
-        mle_guide,
+        PredictiveModel(model, mle_guide),
+        D_test,
         num_samples_outer=10000,
         num_samples_inner=1,
         cg_iters=4,  # dimension of params = 4
         pointwise_influence=False,
     )
 
-    test_data_eif = param_eif(D_test)
+    test_data_eif = param_eif()
     assert torch.allclose(
         test_data_eif["guide.treatment_weight_param"][0],
         analytic_eif_at_test_pts.mean(),
diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py
index 3f91377c7..5788e9074 100644
--- a/tests/robust/test_ops.py
+++ b/tests/robust/test_ops.py
@@ -6,7 +6,7 @@
 import torch
 from typing_extensions import ParamSpec
 
-from chirho.robust.internals.predictive import PredictiveFunctional
+from chirho.robust.internals.predictive import PredictiveFunctional, PredictiveModel
 from chirho.robust.ops import influence_fn
 
 from .robust_fixtures import SimpleGuide, SimpleModel
 
@@ -56,16 +56,6 @@ def test_nmc_predictive_influence_smoke(
     guide = guide(model)
     model(), guide()  # initialize
 
-    predictive_eif = influence_fn(
-        model,
-        guide,
-        functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
-        max_plate_nesting=max_plate_nesting,
-        num_samples_outer=num_samples_outer,
-        num_samples_inner=num_samples_inner,
-        cg_iters=cg_iters,
-    )
-
     with torch.no_grad():
         test_datum = {
             k: v[0]
@@ -74,7 +64,16 @@ def test_nmc_predictive_influence_smoke(
             )().items()
         }
 
-    test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum)
+    predictive_eif = influence_fn(
+        functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
+        test_datum,
+        max_plate_nesting=max_plate_nesting,
+        num_samples_outer=num_samples_outer,
+        num_samples_inner=num_samples_inner,
+        cg_iters=cg_iters,
+    )(PredictiveModel(model, guide))
+
+    test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif()
     assert len(test_datum_eif) > 0
     for k, v in test_datum_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"
@@ -101,22 +100,21 @@ def test_nmc_predictive_influence_vmap_smoke(
     model(), guide()  # initialize
 
+    with torch.no_grad():
+        test_data = pyro.infer.Predictive(
+            model, num_samples=4, return_sites=obs_names, parallel=True
+        )()
+
     predictive_eif = influence_fn(
-        model,
-        guide,
         functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
+        test_data,
         max_plate_nesting=max_plate_nesting,
         num_samples_outer=num_samples_outer,
         num_samples_inner=num_samples_inner,
         cg_iters=cg_iters,
-    )
-
-    with torch.no_grad():
-        test_data = pyro.infer.Predictive(
-            model, num_samples=4, return_sites=obs_names, parallel=True
-        )()
+    )(PredictiveModel(model, guide))
 
-    test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data)
+    test_data_eif: Mapping[str, torch.Tensor] = predictive_eif()
     assert len(test_data_eif) > 0
     for k, v in test_data_eif.items():
         assert not torch.isnan(v).any(), f"eif for {k} had nans"
diff --git a/tests/robust/test_performance.py b/tests/robust/test_performance.py
index b1ec08f29..1c07c07b5 100644
--- a/tests/robust/test_performance.py
+++ b/tests/robust/test_performance.py
@@ -12,7 +12,10 @@
 from chirho.indexed.handlers import DependentMaskMessenger
 from chirho.observational.handlers import condition
 from chirho.robust.internals.linearize import make_empirical_fisher_vp
-from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood
+from chirho.robust.internals.predictive import (
+    BatchedNMCLogPredictiveLikelihood,
+    PredictiveModel,
+)
 from chirho.robust.internals.utils import guess_max_plate_nesting, make_functional_call
 from chirho.robust.ops import Point
 
@@ -149,7 +152,7 @@ def test_empirical_fisher_vp_performance_with_likelihood(model_guide):
     )
 
     log2_prob_params, func2_log_prob = make_functional_call(
-        BatchedNMCLogPredictiveLikelihood(model, guide)
+        BatchedNMCLogPredictiveLikelihood(PredictiveModel(model, guide))
    )
 
     fisher_hessian_vmapped = make_empirical_fisher_vp(
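
A minimal sketch of the new calling convention introduced by this patch, for readers
skimming the diff. It assumes the ``SimpleModel``/``SimpleGuide`` fixtures from
``tests/robust/robust_fixtures.py`` are importable as shown, that ``"y"`` is the
observed site they define, and illustrative sample counts; ``influence_fn`` now takes
the functional and evaluation points up front and returns a ``Functional``, which is
then applied to a single ``PredictiveModel`` wrapping the model and guide:

.. code-block:: python

    import functools

    import pyro
    import torch

    from chirho.robust.internals.predictive import (
        PredictiveFunctional,
        PredictiveModel,
    )
    from chirho.robust.ops import influence_fn
    from tests.robust.robust_fixtures import SimpleGuide, SimpleModel

    model, guide = SimpleModel(), SimpleGuide()
    model(), guide()  # initialize parameters

    # Points at which to evaluate the influence function, drawn from the
    # prior predictive; "y" is assumed to be the observed site.
    with torch.no_grad():
        points = pyro.infer.Predictive(
            model, num_samples=4, return_sites=["y"], parallel=True
        )()

    # Old API: influence_fn(model, guide, functional)(points).
    # New API: influence_fn(functional, points, **linearize_kwargs) returns
    # a Functional, i.e. a map from a model to the efficient influence
    # function of `functional` evaluated at the pre-specified points.
    eif_functional = influence_fn(
        functools.partial(PredictiveFunctional, num_samples=100),
        points,
        num_samples_outer=1000,
        num_samples_inner=1,
    )

    # The guide is folded into the model argument via PredictiveModel, so a
    # single callable is all the returned Functional needs.
    eif = eif_functional(PredictiveModel(model, guide))
    eif_values = eif()  # mapping from site name to influence estimates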