Skip to content

Commit

Permalink
Make influence_fn a higher-order Functional (#492)
Browse files Browse the repository at this point in the history
* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring
  • Loading branch information
eb8680 authored Jan 11, 2024
1 parent 013d518 commit c4346c8
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 91 deletions.
29 changes: 10 additions & 19 deletions chirho/robust/handlers/estimators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, TypeVar
from typing import TypeVar

from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec

from chirho.robust.ops import Functional, Point, influence_fn

Expand All @@ -10,21 +10,17 @@


def one_step_correction(
model: Callable[P, Any],
functional: Functional[P, S],
*test_points: Point[T],
**influence_kwargs,
) -> Callable[Concatenate[Point[T], P], S]:
) -> Functional[P, S]:
"""
Returns a function that computes the one-step correction for the
functional at a specified set of test points as discussed in
[1].
Returns a functional that computes the one-step correction for the
functional at a specified set of test points as discussed in [1].
:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param functional: model summary of interest, which is a function of the model.
:type functional: Functional[P, S]
:return: function to compute the one-step correction
:rtype: Callable[Concatenate[Point[T], P], S]
:param functional: model summary functional of interest
:param test_points: points at which to compute the one-step correction
:return: functional to compute the one-step correction
**References**
Expand All @@ -33,9 +29,4 @@ def one_step_correction(
"""
influence_kwargs_one_step = influence_kwargs.copy()
influence_kwargs_one_step["pointwise_influence"] = False
eif_fn = influence_fn(model, functional, **influence_kwargs_one_step)

def _one_step(test_data: Point[T], *args, **kwargs) -> S:
return eif_fn(test_data, *args, **kwargs)

return _one_step
return influence_fn(functional, *test_points, **influence_kwargs_one_step)
8 changes: 6 additions & 2 deletions chirho/robust/internals/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:


def linearize(
model: Callable[P, Any],
*,
*models: Callable[P, Any],
num_samples_outer: int,
num_samples_inner: Optional[int] = None,
max_plate_nesting: Optional[int] = None,
Expand Down Expand Up @@ -328,6 +327,11 @@ def forward(self):
This issue will be addressed in a future release:
https://github.com/BasisResearch/chirho/issues/393.
"""
if len(models) > 1:
raise NotImplementedError("Only unary version of linearize is implemented.")
else:
(model,) = models

assert isinstance(model, torch.nn.Module)
if num_samples_inner is None:
num_samples_inner = num_samples_outer**2
Expand Down
98 changes: 56 additions & 42 deletions chirho/robust/ops.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,35 @@
import functools
from typing import Any, Callable, Mapping, TypeVar
from typing import Any, Callable, Mapping, Protocol, TypeVar

import torch
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec

from chirho.observational.ops import Observation

P = ParamSpec("P")
Q = ParamSpec("Q")
S = TypeVar("S")
S = TypeVar("S", covariant=True)
T = TypeVar("T")

Point = Mapping[str, Observation[T]]
Functional = Callable[[Callable[P, Any]], Callable[P, S]]


class Functional(Protocol[P, S]):
def __call__(
self, __model: Callable[P, Any], *models: Callable[P, Any]
) -> Callable[P, S]:
...


def influence_fn(
model: Callable[P, Any], functional: Functional[P, S], **linearize_kwargs
) -> Callable[Concatenate[Point[T], P], S]:
functional: Functional[P, S], *points: Point[T], **linearize_kwargs
) -> Functional[P, S]:
"""
Returns the efficient influence function for ``functional``
with respect to the parameters of probabilistic program ``model``.
Returns a new functional that computes the efficient influence function for ``functional``
at the given ``points`` with respect to the parameters of its probabilistic program arguments.
:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param functional: model summary of interest, which is a function of ``model``
:type functional: Functional[P, S]
:return: the efficient influence function for ``functional``
:rtype: Callable[Concatenate[Point[T], P], S]
:param points: points for each input to ``functional`` at which to compute the efficient influence function
:return: functional that computes the efficient influence function for ``functional`` at ``points``
**Example usage**:
Expand Down Expand Up @@ -88,14 +90,13 @@ def forward(self):
)
points = predictive()
influence = influence_fn(
model,
guide,
SimpleFunctional,
points,
num_samples_outer=1000,
num_samples_inner=1000,
)
)(PredictiveModel(model, guide))
influence(points)
influence()
.. note::
Expand All @@ -111,31 +112,44 @@ def forward(self):
from chirho.robust.internals.linearize import linearize
from chirho.robust.internals.utils import make_functional_call

linearized = linearize(model, **linearize_kwargs)
target = functional(model)

# TODO check that target_params == model_params
assert isinstance(target, torch.nn.Module)
target_params, func_target = make_functional_call(target)
if len(points) != 1:
raise NotImplementedError(
"influence_fn currently only supports unary functionals"
)

@functools.wraps(target)
def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
def _influence_functional(*models: Callable[P, Any]) -> Callable[P, S]:
"""
Evaluates the efficient influence function for ``functional`` at each
point in ``points``.
Functional representing the efficient influence function of ``functional`` at ``points`` .
:param points: points at which to compute the efficient influence function
:type points: Point[T]
:return: efficient influence function evaluated at each point in ``points`` or averaged
:rtype: S
:param models: Python callables containing Pyro primitives.
:return: efficient influence function for ``functional`` evaluated at ``model`` and ``points``
"""
param_eif = linearized(points, *args, **kwargs)
return torch.vmap(
lambda d: torch.func.jvp(
lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
)[1],
in_dims=0,
randomness="different",
)(param_eif)

return _fn
if len(models) != len(points):
raise ValueError("mismatch between number of models and points")

linearized = linearize(*models, **linearize_kwargs)
target = functional(*models)

# TODO check that target_params == model_params
assert isinstance(target, torch.nn.Module)
target_params, func_target = make_functional_call(target)

def _fn(*args: P.args, **kwargs: P.kwargs) -> S:
"""
Evaluates the efficient influence function for ``functional`` at each
point in ``points``.
:return: efficient influence function evaluated at each point in ``points`` or averaged
"""
param_eif = linearized(*points, *args, **kwargs)
return torch.vmap(
lambda d: torch.func.jvp(
lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
)[1],
in_dims=0,
randomness="different",
)(param_eif)

return _fn

return _influence_functional
20 changes: 10 additions & 10 deletions tests/robust/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,6 @@ def test_one_step_correction_smoke(
guide = guide(model)
model(), guide() # initialize

one_step = one_step_correction(
PredictiveModel(model, guide),
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)

with torch.no_grad():
test_datum = {
k: v[0]
Expand All @@ -73,7 +64,16 @@ def test_one_step_correction_smoke(
)().items()
}

one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum)
one_step = one_step_correction(
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
test_datum,
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)(PredictiveModel(model, guide))

one_step_on_test: Mapping[str, torch.Tensor] = one_step()
assert len(one_step_on_test) > 0
for k, v in one_step_on_test.items():
assert not torch.isnan(v).any(), f"one_step for {k} had nans"
Expand Down
36 changes: 18 additions & 18 deletions tests/robust/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,6 @@ def test_nmc_predictive_influence_smoke(
guide = guide(model)
model(), guide() # initialize

predictive_eif = influence_fn(
PredictiveModel(model, guide),
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)

with torch.no_grad():
test_datum = {
k: v[0]
Expand All @@ -73,7 +64,16 @@ def test_nmc_predictive_influence_smoke(
)().items()
}

test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum)
predictive_eif = influence_fn(
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
test_datum,
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)(PredictiveModel(model, guide))

test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif()
assert len(test_datum_eif) > 0
for k, v in test_datum_eif.items():
assert not torch.isnan(v).any(), f"eif for {k} had nans"
Expand All @@ -100,21 +100,21 @@ def test_nmc_predictive_influence_vmap_smoke(

model(), guide() # initialize

with torch.no_grad():
test_data = pyro.infer.Predictive(
model, num_samples=4, return_sites=obs_names, parallel=True
)()

predictive_eif = influence_fn(
PredictiveModel(model, guide),
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
test_data,
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)

with torch.no_grad():
test_data = pyro.infer.Predictive(
model, num_samples=4, return_sites=obs_names, parallel=True
)()
)(PredictiveModel(model, guide))

test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data)
test_data_eif: Mapping[str, torch.Tensor] = predictive_eif()
assert len(test_data_eif) > 0
for k, v in test_data_eif.items():
assert not torch.isnan(v).any(), f"eif for {k} had nans"
Expand Down

0 comments on commit c4346c8

Please sign in to comment.