Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make influence_fn a higher-order Functional #492

Merged
merged 6 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions chirho/robust/handlers/estimators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, TypeVar
from typing import TypeVar

from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec

from chirho.robust.ops import Functional, Point, influence_fn

Expand All @@ -10,21 +10,17 @@


def one_step_correction(
model: Callable[P, Any],
functional: Functional[P, S],
*test_points: Point[T],
**influence_kwargs,
) -> Callable[Concatenate[Point[T], P], S]:
) -> Functional[P, S]:
"""
Returns a function that computes the one-step correction for the
functional at a specified set of test points as discussed in
[1].
Returns a functional that computes the one-step correction for the
functional at a specified set of test points as discussed in [1].

:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param functional: model summary of interest, which is a function of the model.
:type functional: Functional[P, S]
:return: function to compute the one-step correction
:rtype: Callable[Concatenate[Point[T], P], S]
:param functional: model summary functional of interest
:param test_points: points at which to compute the one-step correction
:return: functional to compute the one-step correction

**References**

Expand All @@ -33,9 +29,4 @@ def one_step_correction(
"""
influence_kwargs_one_step = influence_kwargs.copy()
influence_kwargs_one_step["pointwise_influence"] = False
eif_fn = influence_fn(model, functional, **influence_kwargs_one_step)

def _one_step(test_data: Point[T], *args, **kwargs) -> S:
return eif_fn(test_data, *args, **kwargs)

return _one_step
return influence_fn(functional, *test_points, **influence_kwargs_one_step)
8 changes: 6 additions & 2 deletions chirho/robust/internals/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:


def linearize(
model: Callable[P, Any],
*,
*models: Callable[P, Any],
num_samples_outer: int,
num_samples_inner: Optional[int] = None,
max_plate_nesting: Optional[int] = None,
Expand Down Expand Up @@ -328,6 +327,11 @@ def forward(self):
This issue will be addressed in a future release:
https://github.com/BasisResearch/chirho/issues/393.
"""
if len(models) > 1:
raise NotImplementedError("Only unary version of linearize is implemented.")
else:
(model,) = models

assert isinstance(model, torch.nn.Module)
if num_samples_inner is None:
num_samples_inner = num_samples_outer**2
Expand Down
98 changes: 56 additions & 42 deletions chirho/robust/ops.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,35 @@
import functools
from typing import Any, Callable, Mapping, TypeVar
from typing import Any, Callable, Mapping, Protocol, TypeVar

import torch
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec

from chirho.observational.ops import Observation

P = ParamSpec("P")
Q = ParamSpec("Q")
S = TypeVar("S")
S = TypeVar("S", covariant=True)
T = TypeVar("T")

Point = Mapping[str, Observation[T]]
Functional = Callable[[Callable[P, Any]], Callable[P, S]]


class Functional(Protocol[P, S]):
def __call__(
self, __model: Callable[P, Any], *models: Callable[P, Any]
) -> Callable[P, S]:
...


def influence_fn(
model: Callable[P, Any], functional: Functional[P, S], **linearize_kwargs
) -> Callable[Concatenate[Point[T], P], S]:
functional: Functional[P, S], *points: Point[T], **linearize_kwargs
) -> Functional[P, S]:
"""
Returns the efficient influence function for ``functional``
with respect to the parameters of probabilistic program ``model``.
Returns a new functional that computes the efficient influence function for ``functional``
at the given ``points`` with respect to the parameters of its probabilistic program arguments.

:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param functional: model summary of interest, which is a function of ``model``
:type functional: Functional[P, S]
:return: the efficient influence function for ``functional``
:rtype: Callable[Concatenate[Point[T], P], S]
:param points: points for each input to ``functional`` at which to compute the efficient influence function
:return: functional that computes the efficient influence function for ``functional`` at ``points``

**Example usage**:

Expand Down Expand Up @@ -88,14 +90,13 @@ def forward(self):
)
points = predictive()
influence = influence_fn(
model,
guide,
SimpleFunctional,
points,
num_samples_outer=1000,
num_samples_inner=1000,
)
)(PredictiveModel(model, guide))

influence(points)
influence()

.. note::

Expand All @@ -111,31 +112,44 @@ def forward(self):
from chirho.robust.internals.linearize import linearize
from chirho.robust.internals.utils import make_functional_call

linearized = linearize(model, **linearize_kwargs)
target = functional(model)

# TODO check that target_params == model_params
assert isinstance(target, torch.nn.Module)
target_params, func_target = make_functional_call(target)
if len(points) != 1:
raise NotImplementedError(
"influence_fn currently only supports unary functionals"
)

@functools.wraps(target)
def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
def _influence_functional(*models: Callable[P, Any]) -> Callable[P, S]:
"""
Evaluates the efficient influence function for ``functional`` at each
point in ``points``.
Functional representing the efficient influence function of ``functional`` at ``points`` .

:param points: points at which to compute the efficient influence function
:type points: Point[T]
:return: efficient influence function evaluated at each point in ``points`` or averaged
:rtype: S
:param models: Python callables containing Pyro primitives.
:return: efficient influence function for ``functional`` evaluated at ``model`` and ``points``
"""
param_eif = linearized(points, *args, **kwargs)
return torch.vmap(
lambda d: torch.func.jvp(
lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
)[1],
in_dims=0,
randomness="different",
)(param_eif)

return _fn
if len(models) != len(points):
raise ValueError("mismatch between number of models and points")

linearized = linearize(*models, **linearize_kwargs)
target = functional(*models)

# TODO check that target_params == model_params
assert isinstance(target, torch.nn.Module)
target_params, func_target = make_functional_call(target)

def _fn(*args: P.args, **kwargs: P.kwargs) -> S:
"""
Evaluates the efficient influence function for ``functional`` at each
point in ``points``.

:return: efficient influence function evaluated at each point in ``points`` or averaged
"""
param_eif = linearized(*points, *args, **kwargs)
return torch.vmap(
lambda d: torch.func.jvp(
lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
)[1],
in_dims=0,
randomness="different",
)(param_eif)

return _fn

return _influence_functional
20 changes: 10 additions & 10 deletions tests/robust/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,6 @@ def test_one_step_correction_smoke(
guide = guide(model)
model(), guide() # initialize

one_step = one_step_correction(
PredictiveModel(model, guide),
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)

with torch.no_grad():
test_datum = {
k: v[0]
Expand All @@ -73,7 +64,16 @@ def test_one_step_correction_smoke(
)().items()
}

one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum)
one_step = one_step_correction(
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
test_datum,
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)(PredictiveModel(model, guide))

one_step_on_test: Mapping[str, torch.Tensor] = one_step()
assert len(one_step_on_test) > 0
for k, v in one_step_on_test.items():
assert not torch.isnan(v).any(), f"one_step for {k} had nans"
Expand Down
36 changes: 18 additions & 18 deletions tests/robust/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,6 @@ def test_nmc_predictive_influence_smoke(
guide = guide(model)
model(), guide() # initialize

predictive_eif = influence_fn(
PredictiveModel(model, guide),
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)

with torch.no_grad():
test_datum = {
k: v[0]
Expand All @@ -73,7 +64,16 @@ def test_nmc_predictive_influence_smoke(
)().items()
}

test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif(test_datum)
predictive_eif = influence_fn(
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
test_datum,
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)(PredictiveModel(model, guide))

test_datum_eif: Mapping[str, torch.Tensor] = predictive_eif()
assert len(test_datum_eif) > 0
for k, v in test_datum_eif.items():
assert not torch.isnan(v).any(), f"eif for {k} had nans"
Expand All @@ -100,21 +100,21 @@ def test_nmc_predictive_influence_vmap_smoke(

model(), guide() # initialize

with torch.no_grad():
test_data = pyro.infer.Predictive(
model, num_samples=4, return_sites=obs_names, parallel=True
)()

predictive_eif = influence_fn(
PredictiveModel(model, guide),
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
test_data,
max_plate_nesting=max_plate_nesting,
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
)

with torch.no_grad():
test_data = pyro.infer.Predictive(
model, num_samples=4, return_sites=obs_names, parallel=True
)()
)(PredictiveModel(model, guide))

test_data_eif: Mapping[str, torch.Tensor] = predictive_eif(test_data)
test_data_eif: Mapping[str, torch.Tensor] = predictive_eif()
assert len(test_data_eif) > 0
for k, v in test_data_eif.items():
assert not torch.isnan(v).any(), f"eif for {k} had nans"
Expand Down
Loading