Skip to content

Commit

Permalink
Finite Difference Baseline (#508)
Browse files Browse the repository at this point in the history
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* add abstractions and simple temp scratch to test with squared unit normal functional with perturbation.

* removes old scratch notebook

* gets squared density running under abstraction that couples functionals and models

* gets quad and mc approximations to match, vectorization hacky.

* adds plotting and comparative to analytic.

* adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas

* fixes dataset splitting, breaks analytic eif

* unfixes an incorrect fix, working now.

* refactors finite difference machinery to fit experimental specs.

* switches to existing rng seed context manager.

* reverts back to what turns out to be a slightly different seeding context.

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
  • Loading branch information
6 people committed Jan 19, 2024
1 parent 325d4b0 commit 72452d1
Show file tree
Hide file tree
Showing 6 changed files with 453 additions and 0 deletions.
Empty file.
174 changes: 174 additions & 0 deletions docs/examples/robust_paper/finite_difference_eif/abstractions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import torch
import pyro
import pyro.distributions as dist
from typing import Dict, Optional
from contextlib import contextmanager
from chirho.robust.ops import Point, T
import numpy as np


class ModelWithMarginalDensity(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def density(self, *args, **kwargs):
# TODO this can probably default to using BatchedNMCLogMarginalLikelihood applied to self,
# but providing here to avail of analytic densities. Or have a constructor that takes a
# regular model and puts the marginal density here.
raise NotImplementedError()

def forward(self, *args, **kwargs):
raise NotImplementedError()


class PrefixMessenger(pyro.poutine.messenger.Messenger):

def __init__(self, prefix: str):
self.prefix = prefix

def _pyro_sample(self, msg) -> None:
msg["name"] = f"{self.prefix}{msg['name']}"


class FDModelFunctionalDensity(ModelWithMarginalDensity):
"""
This class serves to couple the forward sampling model, density, and functional. Finite differencing
operates in the space of densities, and therefore requires of its functionals that they "know about"
the causal structure of the generative model. Thus, the three components are coupled together here.
"""

model: ModelWithMarginalDensity

# TODO These managers are weird but lets you define a valid model at init time and then temporarily
# modify the perturbation later, eg. in the influence function approximatoin.
# TODO pull out boilerplate
@contextmanager
def set_eps(self, eps):
original_eps = self._eps
self._eps = eps
try:
yield
finally:
self._eps = original_eps

@contextmanager
def set_lambda(self, lambda_):
original_lambda = self._lambda
self._lambda = lambda_
try:
yield
finally:
self._lambda = original_lambda

@contextmanager
def set_kernel_point(self, kernel_point: Dict):
original_kernel_point = self._kernel_point
self._kernel_point = kernel_point
try:
yield
finally:
self._kernel_point = original_kernel_point

@property
def kernel(self) -> ModelWithMarginalDensity:
# TODO implementation of a kernel could be brought up to this level. User would need to pass a kernel type
# that's parameterized by the kernel point and lambda.
"""
Inheritors should construct the kernel here as a function of self._kernel_point and self._lambda.
:return:
"""
raise NotImplementedError()

def __init__(self, default_kernel_point: Dict, *args, default_eps=0., default_lambda=0.1, **kwargs):
super().__init__(*args, **kwargs)
self._eps = default_eps
self._lambda = default_lambda
self._kernel_point = default_kernel_point
# TODO don't assume .shape[-1]
self.ndims = np.sum([v.shape[-1] for v in self._kernel_point.values()])

@property
def mixture_weights(self):
return torch.tensor([1. - self._eps, self._eps])

def density(self, model_kwargs: Dict, kernel_kwargs: Dict):
mpart = self.mixture_weights[0] * self.model.density(**model_kwargs)
kpart = self.mixture_weights[1] * self.kernel.density(**kernel_kwargs)
return mpart + kpart

def forward(self, model_kwargs: Optional[Dict] = None, kernel_kwargs: Optional[Dict] = None):
# _from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.mixture_weights))
#
# if _from_kernel:
# return self.kernel(**(kernel_kwargs or dict()))
# else:
# return self.model(**(model_kwargs or dict()))

_from_kernel = pyro.sample('_mixture_assignment', dist.Categorical(self.mixture_weights))

kernel_mask = _from_kernel.bool() # Convert to boolean mask

# Apply the respective functions using the masks
with PrefixMessenger('kernel_'), pyro.poutine.trace() as kernel_tr:
kernel_result = self.kernel(**(kernel_kwargs or dict()))
with PrefixMessenger('model_'), pyro.poutine.trace() as model_tr:
model_result = self.model(**(model_kwargs or dict()))

# FIXME to make log likelihoods work properly, the log likelihoods need to be masked/not added
# for particular elements. See e.g. MaskedMixture for a non-general example of how to do this (it
# uses torch distributions instead of arbitrary probabilistic programs.
# https://docs.pyro.ai/en/stable/distributions.html?highlight=MaskedMixture#maskedmixture
# FIXME ideally the trace would have elements of the same name as well here.

# FIXME where isn't shape agnostic.

# Use masks to select the appropriate result for each sample
result = torch.where(kernel_mask[:, None], kernel_result, model_result)

return result

def functional(self, *args, **kwargs):
# TODO update docstring to this being build_functional instead of just functional
"""
The functional target for this model. This is tightly coupled to a particular
pyro model because finite differencing operates in the space of densities, and
automatically exploit any structure of the pyro model the functional
is being evaluated with respect to. As such, the functional must be implemented
with the specific structure of coupled pyro model in mind.
:param args:
:param kwargs:
:return: An estimate of the functional for ths model.
"""
raise NotImplementedError()


# TODO move this to chirho/robust/ops.py and resolve signature mismatches? Maybe. The problem is that the ops
# signature (rightly) decouples models and functionals, whereas for finite differencing they must be coupled
# because the functional (in many cases) must know about the causal structure of the model.
def fd_influence_fn(fd_coupling: FDModelFunctionalDensity, points: Point[T], eps: float, lambda_: float):

def _influence_fn(*args, **kwargs):

# Length of first value in points mappping.
len_points = len(list(points.values())[0])
eif_vals = []
for i in range(len_points):
kernel_point = {k: v[i] for k, v in points.items()}

# Evaluate the original functional.
psi_p = fd_coupling.functional(*args, **kwargs)

# Evaluate the functional of the perturbation.
with (fd_coupling.set_eps(eps),
fd_coupling.set_lambda(lambda_),
fd_coupling.set_kernel_point(kernel_point)):
psi_p_eps = fd_coupling.functional(*args, **kwargs)

# Record the finite difference.
eif_vals.append((psi_p_eps - psi_p) / eps)
return eif_vals

return _influence_fn


34 changes: 34 additions & 0 deletions docs/examples/robust_paper/finite_difference_eif/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from .abstractions import ModelWithMarginalDensity, FDModelFunctionalDensity
from scipy.stats import multivariate_normal
import pyro
import pyro.distributions as dist


class MultivariateNormalwDensity(ModelWithMarginalDensity):

def __init__(self, mean, scale_tril, *args, **kwargs):
super().__init__(*args, **kwargs)

self.mean = mean
self.scale_tril = scale_tril

# Convert scale_tril to a covariance matrix.
self.cov = scale_tril @ scale_tril.T

def density(self, x):
return multivariate_normal.pdf(x, mean=self.mean, cov=self.cov)

def forward(self):
return pyro.sample("x", dist.MultivariateNormal(self.mean, scale_tril=self.scale_tril))


class PerturbableNormal(FDModelFunctionalDensity):

def __init__(self, *args, mean, scale_tril, **kwargs):
super().__init__(*args, **kwargs)

self.ndims = mean.shape[-1]
self.model = MultivariateNormalwDensity(
mean=mean,
scale_tril=scale_tril
)
56 changes: 56 additions & 0 deletions docs/examples/robust_paper/finite_difference_eif/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from .abstractions import FDModelFunctionalDensity
import numpy as np
from scipy.integrate import nquad
import torch
import pyro
from .distributions import MultivariateNormalwDensity


class ExpectedDensityQuadFunctional(FDModelFunctionalDensity):
"""
Compute the squared normal density using quadrature.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def functional(self):
def integrand(*args):
# TODO agnostic to kwarg names.
model_kwargs = kernel_kwargs = dict(x=np.array(args))
return self.density(model_kwargs, kernel_kwargs) ** 2

ndim = self._kernel_point['x'].shape[-1]

return nquad(integrand, [[-np.inf, np.inf]] * ndim)[0]


class ExpectedDensityMCFunctional(FDModelFunctionalDensity):
"""
Compute the squared normal density using Monte Carlo.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def functional(self, nmc=1000):
# TODO agnostic to kwarg names
with pyro.plate('samples', nmc):
points = self()
return torch.mean(self.density(model_kwargs=dict(x=points), kernel_kwargs=dict(x=points)))


class NormalKernel(FDModelFunctionalDensity):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@property
def kernel(self):
# TODO agnostic to names.
mean = self._kernel_point['x']
cov = torch.eye(self.ndims) * self._lambda
return MultivariateNormalwDensity(
mean=mean,
scale_tril=torch.linalg.cholesky(cov)
)
Loading

0 comments on commit 72452d1

Please sign in to comment.