diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index 16e8ab227..4f60ddcd6 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -1,6 +1,6 @@ -from typing import Any, Callable, TypeVar +from typing import TypeVar -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from chirho.robust.ops import Functional, Point, influence_fn @@ -10,21 +10,17 @@ def one_step_correction( - model: Callable[P, Any], functional: Functional[P, S], + *test_points: Point[T], **influence_kwargs, -) -> Callable[Concatenate[Point[T], P], S]: +) -> Functional[P, S]: """ - Returns a function that computes the one-step correction for the - functional at a specified set of test points as discussed in - [1]. + Returns a functional that computes the one-step correction for the + functional at a specified set of test points as discussed in [1]. - :param model: Python callable containing Pyro primitives. - :type model: Callable[P, Any] - :param functional: model summary of interest, which is a function of the model. - :type functional: Functional[P, S] - :return: function to compute the one-step correction - :rtype: Callable[Concatenate[Point[T], P], S] + :param functional: model summary functional of interest + :param test_points: points at which to compute the one-step correction + :return: functional to compute the one-step correction **References** @@ -33,9 +29,4 @@ def one_step_correction( """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False - eif_fn = influence_fn(model, functional, **influence_kwargs_one_step) - - def _one_step(test_data: Point[T], *args, **kwargs) -> S: - return eif_fn(test_data, *args, **kwargs) - - return _one_step + return influence_fn(functional, *test_points, **influence_kwargs_one_step) diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 27ce8da39..29447c736 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -219,8 +219,7 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: def linearize( - model: Callable[P, Any], - *, + *models: Callable[P, Any], num_samples_outer: int, num_samples_inner: Optional[int] = None, max_plate_nesting: Optional[int] = None, @@ -328,6 +327,11 @@ def forward(self): This issue will be addressed in a future release: https://github.com/BasisResearch/chirho/issues/393. """ + if len(models) > 1: + raise NotImplementedError("Only unary version of linearize is implemented.") + else: + (model,) = models + assert isinstance(model, torch.nn.Module) if num_samples_inner is None: num_samples_inner = num_samples_outer**2 diff --git a/chirho/robust/ops.py b/chirho/robust/ops.py index b02bfa47e..86ed8f89d 100644 --- a/chirho/robust/ops.py +++ b/chirho/robust/ops.py @@ -1,33 +1,35 @@ -import functools -from typing import Any, Callable, Mapping, TypeVar +from typing import Any, Callable, Mapping, Protocol, TypeVar import torch -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from chirho.observational.ops import Observation P = ParamSpec("P") Q = ParamSpec("Q") -S = TypeVar("S") +S = TypeVar("S", covariant=True) T = TypeVar("T") Point = Mapping[str, Observation[T]] -Functional = Callable[[Callable[P, Any]], Callable[P, S]] + + +class Functional(Protocol[P, S]): + def __call__( + self, __model: Callable[P, Any], *models: Callable[P, Any] + ) -> Callable[P, S]: + ... def influence_fn( - model: Callable[P, Any], functional: Functional[P, S], **linearize_kwargs -) -> Callable[Concatenate[Point[T], P], S]: + functional: Functional[P, S], *points: Point[T], **linearize_kwargs +) -> Functional[P, S]: """ - Returns the efficient influence function for ``functional`` - with respect to the parameters of probabilistic program ``model``. + Returns a new functional that computes the efficient influence function for ``functional`` + at the given ``points`` with respect to the parameters of its probabilistic program arguments. - :param model: Python callable containing Pyro primitives. - :type model: Callable[P, Any] :param functional: model summary of interest, which is a function of ``model`` - :type functional: Functional[P, S] - :return: the efficient influence function for ``functional`` - :rtype: Callable[Concatenate[Point[T], P], S] + :param points: points for each input to ``functional`` at which to compute the efficient influence function + :return: functional that computes the efficient influence function for ``functional`` at ``points`` **Example usage**: @@ -88,14 +90,13 @@ def forward(self): ) points = predictive() influence = influence_fn( - model, - guide, SimpleFunctional, + points, num_samples_outer=1000, num_samples_inner=1000, - ) + )(PredictiveModel(model, guide)) - influence(points) + influence() .. note:: @@ -111,31 +112,44 @@ def forward(self): from chirho.robust.internals.linearize import linearize from chirho.robust.internals.utils import make_functional_call - linearized = linearize(model, **linearize_kwargs) - target = functional(model) - - # TODO check that target_params == model_params - assert isinstance(target, torch.nn.Module) - target_params, func_target = make_functional_call(target) + if len(points) != 1: + raise NotImplementedError( + "influence_fn currently only supports unary functionals" + ) - @functools.wraps(target) - def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S: + def _influence_functional(*models: Callable[P, Any]) -> Callable[P, S]: """ - Evaluates the efficient influence function for ``functional`` at each - point in ``points``. + Functional representing the efficient influence function of ``functional`` at ``points`` . - :param points: points at which to compute the efficient influence function - :type points: Point[T] - :return: efficient influence function evaluated at each point in ``points`` or averaged - :rtype: S + :param models: Python callables containing Pyro primitives. + :return: efficient influence function for ``functional`` evaluated at ``model`` and ``points`` """ - param_eif = linearized(points, *args, **kwargs) - return torch.vmap( - lambda d: torch.func.jvp( - lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) - )[1], - in_dims=0, - randomness="different", - )(param_eif) - - return _fn + if len(models) != len(points): + raise ValueError("mismatch between number of models and points") + + linearized = linearize(*models, **linearize_kwargs) + target = functional(*models) + + # TODO check that target_params == model_params + assert isinstance(target, torch.nn.Module) + target_params, func_target = make_functional_call(target) + + def _fn(*args: P.args, **kwargs: P.kwargs) -> S: + """ + Evaluates the efficient influence function for ``functional`` at each + point in ``points``. + + :return: efficient influence function evaluated at each point in ``points`` or averaged + """ + param_eif = linearized(*points, *args, **kwargs) + return torch.vmap( + lambda d: torch.func.jvp( + lambda p: func_target(p, *args, **kwargs), (target_params,), (d,) + )[1], + in_dims=0, + randomness="different", + )(param_eif) + + return _fn + + return _influence_functional diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py index f1849959a..6168d563a 100644 --- a/tests/robust/test_handlers.py +++ b/tests/robust/test_handlers.py @@ -56,15 +56,6 @@ def test_one_step_correction_smoke( guide = guide(model) model(), guide() # initialize - one_step = one_step_correction( - PredictiveModel(model, guide), - functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - with torch.no_grad(): test_datum = { k: v[0] @@ -73,7 +64,16 @@ def test_one_step_correction_smoke( )().items() } - one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum) + one_step = one_step_correction( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + one_step_on_test: Mapping[str, torch.Tensor] = one_step() assert len(one_step_on_test) > 0 for k, v in one_step_on_test.items(): assert not torch.isnan(v).any(), f"one_step for {k} had nans" diff --git a/tests/robust/test_ops.py b/tests/robust/test_ops.py index 1bdb2461b..e3d5e5290 100644 --- a/tests/robust/test_ops.py +++ b/tests/robust/test_ops.py @@ -56,15 +56,6 @@ def test_nmc_predictive_influence_smoke( guide = guide(model) model(), guide() # initialize - predictive_eif = influence_fn( - PredictiveModel(model, guide), - functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), - max_plate_nesting=max_plate_nesting, - num_samples_outer=num_samples_outer, - num_samples_inner=num_samples_inner, - cg_iters=cg_iters, - ) - with torch.no_grad(): test_datum = { k: v[0] @@ -73,7 +64,16 @@ def test_nmc_predictive_influence_smoke( )().items() } - test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum) + predictive_eif = influence_fn( + functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_datum, + max_plate_nesting=max_plate_nesting, + num_samples_outer=num_samples_outer, + num_samples_inner=num_samples_inner, + cg_iters=cg_iters, + )(PredictiveModel(model, guide)) + + test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif() assert len(test_datum_eif) > 0 for k, v in test_datum_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans" @@ -100,21 +100,21 @@ def test_nmc_predictive_influence_vmap_smoke( model(), guide() # initialize + with torch.no_grad(): + test_data = pyro.infer.Predictive( + model, num_samples=4, return_sites=obs_names, parallel=True + )() + predictive_eif = influence_fn( - PredictiveModel(model, guide), functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), + test_data, max_plate_nesting=max_plate_nesting, num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, cg_iters=cg_iters, - ) - - with torch.no_grad(): - test_data = pyro.infer.Predictive( - model, num_samples=4, return_sites=obs_names, parallel=True - )() + )(PredictiveModel(model, guide)) - test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data) + test_data_eif: Mapping[str, torch.Tensor] = predictive_eif() assert len(test_data_eif) > 0 for k, v in test_data_eif.items(): assert not torch.isnan(v).any(), f"eif for {k} had nans"