Make influence_fn return a Functional #488

Closed
29 changes: 11 additions & 18 deletions chirho/robust/handlers/estimators.py
@@ -1,27 +1,25 @@
-from typing import Any, Callable
+from typing import TypeVar
 
-from typing_extensions import Concatenate
+from typing_extensions import ParamSpec
 
-from chirho.robust.ops import Functional, P, Point, S, T, influence_fn
+from chirho.robust.ops import Functional, Point, influence_fn
+
+P = ParamSpec("P")
+S = TypeVar("S")
+T = TypeVar("T")
 
 
 def one_step_correction(
-    model: Callable[P, Any],
-    guide: Callable[P, Any],
     functional: Functional[P, S],
+    test_data: Point[T],
     **influence_kwargs,
-) -> Callable[Concatenate[Point[T], P], S]:
+) -> Functional[P, S]:
     """
     Returns a function that computes the one-step correction for the
     functional at a specified set of test points as discussed in
     [1].
 
-    :param model: Python callable containing Pyro primitives.
-    :type model: Callable[P, Any]
-    :param guide: Python callable containing Pyro primitives.
-    :type guide: Callable[P, Any]
-    :param functional: model summary of interest, which is a function of the
-        model and guide.
+    :param functional: model summary of interest
     :type functional: Functional[P, S]
     :return: function to compute the one-step correction
     :rtype: Callable[Concatenate[Point[T], P], S]
@@ -33,9 +31,4 @@ def one_step_correction(
     """
     influence_kwargs_one_step = influence_kwargs.copy()
     influence_kwargs_one_step["pointwise_influence"] = False
-    eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step)
-
-    def _one_step(test_data: Point[T], *args, **kwargs) -> S:
-        return eif_fn(test_data, *args, **kwargs)
-
-    return _one_step
+    return influence_fn(functional, test_data, **influence_kwargs_one_step)
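
For review context, here is a minimal sketch of the new calling convention for ``one_step_correction``. The model, guide, and data below are illustrative assumptions, not part of this diff; only the call pattern comes from the code above.

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule

from chirho.robust.handlers.estimators import one_step_correction
from chirho.robust.internals.predictive import PredictiveFunctional, PredictiveModel


class Model(PyroModule):  # hypothetical model, for illustration only
    def forward(self):
        a = pyro.sample("a", dist.Normal(0.0, 1.0))
        return pyro.sample("x", dist.Normal(a, 1.0))


class Guide(PyroModule):  # hypothetical guide over the latent "a"
    def __init__(self):
        super().__init__()
        self.loc = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self):
        pyro.sample("a", dist.Normal(self.loc, 1.0))


model, guide = Model(), Guide()
test_data = {"x": torch.randn(10)}

# one_step_correction(functional, test_data, **kwargs) now returns a
# Functional: applying it to a (predictive) model yields the estimator.
correction = one_step_correction(
    PredictiveFunctional,
    test_data,
    num_samples_outer=100,
    num_samples_inner=100,
)(PredictiveModel(model, guide))

result = correction()  # test points were bound when the Functional was built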
22 changes: 8 additions & 14 deletions chirho/robust/internals/linearize.py
@@ -220,18 +220,18 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:
 
 def linearize(
     model: Callable[P, Any],
-    guide: Callable[P, Any],
+    points: Point[T],
     *,
     num_samples_outer: int,
     num_samples_inner: Optional[int] = None,
     max_plate_nesting: Optional[int] = None,
     cg_iters: Optional[int] = None,
     residual_tol: float = 1e-4,
     pointwise_influence: bool = True,
-) -> Callable[Concatenate[Point[T], P], ParamDict]:
+) -> Callable[P, ParamDict]:
     r"""
     Returns the influence function associated with the parameters
-    of ``guide`` and probabilistic program ``model``. This function
+    of probabilistic program ``model``. This function
     computes the following quantity at an arbitrary point :math:`x^{\prime}`:
 
     .. math::
@@ -248,9 +248,6 @@
 
     :param model: Python callable containing Pyro primitives.
     :type model: Callable[P, Any]
-    :param guide: Python callable containing Pyro primitives.
-        Must only contain continuous latent variables.
-    :type guide: Callable[P, Any]
     :param num_samples_outer: number of Monte Carlo samples to
         approximate Fisher information in :func:`make_empirical_fisher_vp`
     :type num_samples_outer: int
@@ -271,7 +268,7 @@
         over ``points``. Defaults to True.
     :type pointwise_influence: bool, optional
     :return: the influence function associated with the parameters
-    :rtype: Callable[Concatenate[Point[T], P], ParamDict]
+    :rtype: Callable[P, ParamDict]
 
     **Example usage**:
 
@@ -312,13 +309,13 @@ def forward(self):
        )
        points = predictive()
        influence = linearize(
-           model,
-           guide,
+           PredictiveModel(model, guide),
+           points,
           num_samples_outer=1000,
           num_samples_inner=1000,
        )
 
-       influence(points)
+       influence()
 
    .. note::
 
@@ -332,19 +329,17 @@ def forward(self):
        https://github.com/BasisResearch/chirho/issues/393.
    """
    assert isinstance(model, torch.nn.Module)
-   assert isinstance(guide, torch.nn.Module)
    if num_samples_inner is None:
        num_samples_inner = num_samples_outer**2
 
    predictive = pyro.infer.Predictive(
        model,
-       guide=guide,
        num_samples=num_samples_outer,
        parallel=True,
    )
 
    batched_log_prob = BatchedNMCLogPredictiveLikelihood(
-       model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
+       model, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
    )
    log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob)
    log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
@@ -357,7 +352,6 @@ def forward(self):
    )
 
    def _fn(
-       points: Point[T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> ParamDict:
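
The net effect on callers of ``linearize`` is easiest to see side by side. A sketch, assuming ``model``, ``guide``, and ``points`` are defined as in the docstring example above:

# Before this change: guide passed separately, points supplied at call time.
# influence = linearize(model, guide, num_samples_outer=1000, num_samples_inner=1000)
# influence(points)

# After this change: one predictive program, with points bound up front.
influence = linearize(
    PredictiveModel(model, guide),  # a single torch.nn.Module wrapping model and guide
    points,
    num_samples_outer=1000,
    num_samples_inner=1000,
)
influence()  # evaluates the influence function at the bound points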
17 changes: 9 additions & 8 deletions chirho/robust/internals/predictive.py
@@ -116,12 +116,12 @@ class PredictiveModel(Generic[P, T], torch.nn.Module):
     """
 
     model: Callable[P, T]
-    guide: Callable[P, Any]
+    guide: Optional[Callable[P, Any]]
 
     def __init__(
         self,
         model: Callable[P, T],
-        guide: Callable[P, Any],
+        guide: Optional[Callable[P, Any]] = None,
     ):
         super().__init__()
         self.model = model
@@ -135,7 +135,8 @@ def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
         :rtype: T
         """
         with pyro.poutine.trace() as guide_tr:
-            self.guide(*args, **kwargs)
+            if self.guide is not None:
+                self.guide(*args, **kwargs)
 
         block_guide_sample_sites = pyro.poutine.block(
             hide=[
@@ -175,13 +176,13 @@ class PredictiveFunctional(Generic[P, T], torch.nn.Module):
     """
 
     model: Callable[P, Any]
-    guide: Callable[P, Any]
+    guide: Optional[Callable[P, Any]]
     num_samples: int
 
     def __init__(
         self,
         model: torch.nn.Module,
-        guide: torch.nn.Module,
+        guide: Optional[torch.nn.Module] = None,
         *,
         num_samples: int = 1,
         max_plate_nesting: Optional[int] = None,
@@ -245,13 +246,13 @@ class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module):
     :type num_samples: int, optional
     """
     model: Callable[P, Any]
-    guide: Callable[P, Any]
+    guide: Optional[Callable[P, Any]]
     num_samples: int
 
     def __init__(
         self,
         model: torch.nn.Module,
-        guide: torch.nn.Module,
+        guide: Optional[torch.nn.Module] = None,
         *,
         num_samples: int = 1,
         max_plate_nesting: Optional[int] = None,
@@ -279,7 +280,7 @@ def forward(
         :return: Log predictive likelihood at each datapoint.
         :rtype: torch.Tensor
         """
-        get_nmc_traces = get_importance_traces(PredictiveModel(self.model, self.guide))
+        get_nmc_traces = get_importance_traces(self.model, self.guide)
 
         with IndexPlatesMessenger(first_available_dim=self._first_available_dim):
             with BatchedLatents(self.num_samples, name=self._mc_plate_name):
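
Because ``guide`` is now ``Optional`` with a default of ``None`` in all three classes, a program can be wrapped without any guide. A sketch of what the new default enables, reusing the hypothetical ``model`` and ``guide`` from the earlier example:

# guide=None (the new default) records an empty guide trace and blocks no
# sites, so PredictiveModel reduces to running the model under its prior.
prior_predictive = PredictiveModel(model)             # no guide
posterior_predictive = PredictiveModel(model, guide)  # behaves as before

sample = prior_predictive()  # forward() skips the guide call when guide is None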
88 changes: 40 additions & 48 deletions chirho/robust/ops.py
@@ -1,8 +1,7 @@
-import functools
 from typing import Any, Callable, Mapping, TypeVar
 
 import torch
-from typing_extensions import Concatenate, ParamSpec
+from typing_extensions import ParamSpec
 
 from chirho.observational.ops import Observation
 
@@ -12,30 +11,24 @@
 T = TypeVar("T")
 
 Point = Mapping[str, Observation[T]]
-Functional = Callable[[Callable[P, Any], Callable[P, Any]], Callable[P, S]]
+Functional = Callable[[Callable[P, Any]], Callable[P, S]]
 
 
 def influence_fn(
-    model: Callable[P, Any],
-    guide: Callable[P, Any],
-    functional: Functional[P, S],
-    **linearize_kwargs
-) -> Callable[Concatenate[Point[T], P], S]:
+    functional: Functional[P, S], points: Point[T], **linearize_kwargs
+) -> Functional[P, S]:
     """
     Returns the efficient influence function for ``functional``
-    with respect to the parameters of ``guide`` and probabilistic
-    program ``model``.
+    with respect to the parameters of probabilistic program ``model``.
 
-    :param model: Python callable containing Pyro primitives.
-    :type model: Callable[P, Any]
-    :param guide: Python callable containing Pyro primitives.
-        Must only contain continuous latent variables.
-    :type guide: Callable[P, Any]
-    :param functional: model summary of interest, which is a function of the
-        model and guide.
+    :param functional: model summary of interest, which is a function of the model.
     :type functional: Functional[P, S]
+    :param points: points at which to compute the efficient influence function
+    :type points: Point[T]
     :return: the efficient influence function for ``functional``
-    :rtype: Callable[Concatenate[Point[T], P], S]
+    :rtype: Functional[P, S]
 
    **Example usage**:
 
@@ -95,14 +88,13 @@ def forward(self):
        )
        points = predictive()
        influence = influence_fn(
-           model,
-           guide,
           SimpleFunctional,
+           points,
           num_samples_outer=1000,
           num_samples_inner=1000,
-       )
+       )(PredictiveModel(model, guide))
 
-       influence(points)
+       influence()
 
    .. note::
 
@@ -118,31 +110,31 @@ def forward(self):
    from chirho.robust.internals.linearize import linearize
    from chirho.robust.internals.utils import make_functional_call
 
-   linearized = linearize(model, guide, **linearize_kwargs)
-   target = functional(model, guide)
-
-   # TODO check that target_params == model_params | guide_params
-   assert isinstance(target, torch.nn.Module)
-   target_params, func_target = make_functional_call(target)
-
-   @functools.wraps(target)
-   def _fn(points: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
-       """
-       Evaluates the efficient influence function for ``functional`` at each
-       point in ``points``.
-
-       :param points: points at which to compute the efficient influence function
-       :type points: Point[T]
-       :return: efficient influence function evaluated at each point in ``points`` or averaged
-       :rtype: S
-       """
-       param_eif = linearized(points, *args, **kwargs)
-       return torch.vmap(
-           lambda d: torch.func.jvp(
-               lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
-           )[1],
-           in_dims=0,
-           randomness="different",
-       )(param_eif)
-
-   return _fn
+   def _functional(model: Callable[P, Any]) -> Callable[P, S]:
+       linearized = linearize(model, points, **linearize_kwargs)
+       target = functional(model)
+
+       # TODO check that target_params == model_params | guide_params
+       assert isinstance(target, torch.nn.Module)
+       target_params, func_target = make_functional_call(target)
+
+       def _fn(*args: P.args, **kwargs: P.kwargs) -> S:
+           """
+           Evaluates the efficient influence function for ``functional`` at each
+           point in ``points``.
+
+           :return: efficient influence function evaluated at each point in ``points`` or averaged
+           :rtype: S
+           """
+           param_eif = linearized(*args, **kwargs)
+           return torch.vmap(
+               lambda d: torch.func.jvp(
+                   lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
+               )[1],
+               in_dims=0,
+               randomness="different",
+           )(param_eif)
+
+       return _fn
+
+   return _functional
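
The one-argument ``Functional`` alias is what makes the new API compose: ``influence_fn(functional, points, **kwargs)`` is now itself a ``Functional``, a map from a model to an estimator, rather than an estimator that takes points at call time. A sketch, reusing the names from the docstring example above:

# influence_fn now returns a Functional[P, S] ...
eif_functional = influence_fn(
    SimpleFunctional,
    points,
    num_samples_outer=1000,
    num_samples_inner=1000,
)

# ... so applying it to a model produces the estimator, and calling the
# estimator evaluates the efficient influence function at the bound points.
estimator = eif_functional(PredictiveModel(model, guide))
values = estimator()

Since the input and output of ``influence_fn`` are both Functionals, wrappers such as ``one_step_correction`` reduce to tweaking keyword arguments and delegating, as the estimators.py diff above shows.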
25 changes: 11 additions & 14 deletions tests/robust/test_handlers.py
@@ -7,7 +7,7 @@
 from typing_extensions import ParamSpec
 
 from chirho.robust.handlers.estimators import one_step_correction
-from chirho.robust.internals.predictive import PredictiveFunctional
+from chirho.robust.internals.predictive import PredictiveFunctional, PredictiveModel
 
 from .robust_fixtures import SimpleGuide, SimpleModel
 
@@ -56,18 +56,6 @@ def test_one_step_correction_smoke(
     guide = guide(model)
     model(), guide()  # initialize
 
-    one_step = one_step_correction(
-        model,
-        guide,
-        functional=functools.partial(
-            PredictiveFunctional, num_samples=num_predictive_samples
-        ),
-        max_plate_nesting=max_plate_nesting,
-        num_samples_outer=num_samples_outer,
-        num_samples_inner=num_samples_inner,
-        cg_iters=cg_iters,
-    )
-
     with torch.no_grad():
         test_datum = {
             k: v[0]
@@ -76,7 +64,16 @@
             )().items()
         }
 
-    one_step_on_test: Mapping[str, torch.Tensor] = one_step(test_datum)
+    one_step = one_step_correction(
+        functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
+        test_datum,
+        max_plate_nesting=max_plate_nesting,
+        num_samples_outer=num_samples_outer,
+        num_samples_inner=num_samples_inner,
+        cg_iters=cg_iters,
+    )(PredictiveModel(model, guide))
+
+    one_step_on_test: Mapping[str, torch.Tensor] = one_step()
     assert len(one_step_on_test) > 0
     for k, v in one_step_on_test.items():
         assert not torch.isnan(v).any(), f"one_step for {k} had nans"