Skip to content

Commit

Permalink
Staging branch for chirho.robust module development (#398)
Browse files Browse the repository at this point in the history
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
  • Loading branch information
5 people committed Jan 12, 2024
1 parent 804b846 commit fc4a90d
Show file tree
Hide file tree
Showing 20 changed files with 2,478 additions and 3 deletions.
8 changes: 5 additions & 3 deletions chirho/observational/handlers/condition.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union
from typing import Callable, Generic, Mapping, TypeVar, Union

import pyro
import torch

from chirho.observational.internals import ObserveNameMessenger
from chirho.observational.ops import AtomicObservation, observe
from chirho.observational.ops import Observation, observe

T = TypeVar("T")
R = Union[float, torch.Tensor]
Expand Down Expand Up @@ -64,7 +64,9 @@ class Observations(Generic[T], ObserveNameMessenger):
a richer set of observational data types and enables counterfactual inference.
"""

def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]):
data: Mapping[str, Observation[T]]

def __init__(self, data: Mapping[str, Observation[T]]):
self.data = data
super().__init__()

Expand Down
Empty file added chirho/robust/__init__.py
Empty file.
Empty file.
55 changes: 55 additions & 0 deletions chirho/robust/handlers/estimators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any, Callable, TypeVar

import torch
from typing_extensions import ParamSpec

from chirho.robust.ops import Functional, Point, influence_fn

P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")


def one_step_corrected_estimator(
functional: Functional[P, S],
*test_points: Point[T],
**influence_kwargs,
) -> Functional[P, S]:
"""
Returns a functional that computes the one-step correction for the
functional at a specified set of test points as discussed in [1].
:param functional: model summary functional of interest
:param test_points: points at which to compute the one-step correction
:return: functional to compute the one-step correction
**References**
[1] `Semiparametric doubly robust targeted double machine learning: a review`,
Edward H. Kennedy, 2022.
"""
influence_kwargs_one_step = influence_kwargs.copy()
influence_kwargs_one_step["pointwise_influence"] = False
eif_fn = influence_fn(functional, *test_points, **influence_kwargs_one_step)

def _corrected_functional(*model: Callable[P, Any]) -> Callable[P, S]:
plug_in_estimator = functional(*model)
correction_estimator = eif_fn(*model)

def _estimator(*args, **kwargs) -> S:
plug_in_estimate = plug_in_estimator(*args, **kwargs)
correction = correction_estimator(*args, **kwargs)

flat_plug_in_estimate, treespec = torch.utils._pytree.tree_flatten(
plug_in_estimate
)
flat_correction, _ = torch.utils._pytree.tree_flatten(correction)

return torch.utils._pytree.tree_unflatten(
[a + b for a, b in zip(flat_plug_in_estimate, flat_correction)],
treespec,
)

return _estimator

return _corrected_functional
140 changes: 140 additions & 0 deletions chirho/robust/handlers/predictive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import Any, Callable, Generic, Optional, TypeVar

import pyro
import torch
from typing_extensions import ParamSpec

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.robust.internals.nmc import BatchedLatents
from chirho.robust.internals.utils import bind_leftmost_dim
from chirho.robust.ops import Point

P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")


class PredictiveModel(Generic[P, T], torch.nn.Module):
"""
Given a Pyro model and guide, constructs a new model that behaves as if
the latent ``sample`` sites in the original model (i.e. the prior)
were replaced by their counterparts in the guide (i.e. the posterior).
.. note:: Sites that only appear in the model are annotated in traces
produced by the predictive model with ``infer={"_model_predictive_site": True}`` .
:param model: Pyro model.
:param guide: Pyro guide.
"""

model: Callable[P, T]
guide: Optional[Callable[P, Any]]

def __init__(
self,
model: Callable[P, T],
guide: Optional[Callable[P, Any]] = None,
):
super().__init__()
self.model = model
self.guide = guide

def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""
Returns a sample from the posterior predictive distribution.
:return: Sample from the posterior predictive distribution.
:rtype: T
"""
with pyro.poutine.infer_config(
config_fn=lambda msg: {"_model_predictive_site": False}
):
with pyro.poutine.trace() as guide_tr:
if self.guide is not None:
self.guide(*args, **kwargs)

block_guide_sample_sites = pyro.poutine.block(
hide=[
name
for name, node in guide_tr.trace.nodes.items()
if node["type"] == "sample"
]
)

with pyro.poutine.infer_config(
config_fn=lambda msg: {"_model_predictive_site": True}
):
with block_guide_sample_sites:
with pyro.poutine.replay(trace=guide_tr.trace):
return self.model(*args, **kwargs)


class PredictiveFunctional(Generic[P, T], torch.nn.Module):
"""
Functional that returns a batch of samples from the predictive
distribution of a Pyro model. As with ``pyro.infer.Predictive`` ,
the returned values are batched along their leftmost positional dimension.
Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)``
when :class:`~chirho.robust.handlers.predictive.PredictiveModel` is used to construct
the ``model`` argument and infer the ``sample`` sites whose values should be returned,
and uses :class:`~BatchedLatents` to parallelize over samples from the model.
.. warning:: ``PredictiveFunctional`` currently applies its own internal instance of
:class:`~chirho.indexed.handlers.IndexPlatesMessenger` ,
so it may not behave as expected if used within another enclosing
:class:`~chirho.indexed.handlers.IndexPlatesMessenger` context.
:param model: Pyro model.
:param num_samples: Number of samples to return.
"""

model: Callable[P, Any]
num_samples: int

def __init__(
self,
model: torch.nn.Module,
*,
num_samples: int = 1,
max_plate_nesting: Optional[int] = None,
name: str = "__particles_predictive",
):
super().__init__()
self.model = model
self.num_samples = num_samples
self._first_available_dim = (
-max_plate_nesting - 1 if max_plate_nesting is not None else None
)
self._mc_plate_name = name

def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]:
"""
Returns a batch of samples from the posterior predictive distribution.
:return: Dictionary of samples from the posterior predictive distribution.
:rtype: Point[T]
"""
with IndexPlatesMessenger(first_available_dim=self._first_available_dim):
with pyro.poutine.trace() as model_tr:
with BatchedLatents(self.num_samples, name=self._mc_plate_name):
with pyro.poutine.infer_config(
config_fn=lambda msg: {
"_model_predictive_site": msg["infer"].get(
"_model_predictive_site", True
)
}
):
self.model(*args, **kwargs)

return {
name: bind_leftmost_dim(
node["value"],
self._mc_plate_name,
event_dim=len(node["fn"].event_shape),
)
for name, node in model_tr.trace.nodes.items()
if node["type"] == "sample"
and not pyro.poutine.util.site_is_subsample(node)
and node["infer"].get("_model_predictive_site", False)
}
Empty file.
Loading

0 comments on commit fc4a90d

Please sign in to comment.