diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py
index eb6e8d6ee..779b108e8 100644
--- a/chirho/robust/handlers/estimators.py
+++ b/chirho/robust/handlers/estimators.py
@@ -1,8 +1,13 @@
+import copy
+import warnings
from typing import Any, Callable, TypeVar
import torch
+import torchopt
from typing_extensions import ParamSpec
+from chirho.robust.handlers.predictive import PredictiveFunctional
+from chirho.robust.internals.utils import make_functional_call
from chirho.robust.ops import Functional, Point, influence_fn
P = ParamSpec("P")
@@ -10,9 +15,216 @@
T = TypeVar("T")
+def tmle_scipy_optimize_wrapper(
+    packed_influence: torch.Tensor, log_jitter: float = 1e-6
+) -> torch.Tensor:
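+    """
+    Solve the TMLE fluctuation step with scipy: find the coefficient vector
+    ``epsilon`` maximizing the corrected log-likelihood
+    ``sum_i log(1 + D_i @ epsilon)`` over the test points, subject to the
+    corrected density ``1 + D_i @ epsilon`` remaining nonnegative.
+
+    :param packed_influence: ``(N, L)`` matrix of influence function values,
+        one row per test point.
+    :param log_jitter: lower clamp inside the ``log`` for numerical stability.
+    :return: the fitted ``epsilon`` as a tensor of shape ``(L,)``.
+    """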
+ import numpy as np
+ import scipy
+ from scipy.optimize import LinearConstraint
+
+ # Turn things into numpy. This makes us sad... :(
+ D = packed_influence.detach().numpy()
+
+ N, L = D.shape[0], D.shape[1]
+
+ def loss(epsilon):
+ correction = 1 + D.dot(epsilon)
+
+ return -np.sum(np.log(np.maximum(correction, log_jitter)))
+
+ positive_density_constraint = LinearConstraint(
+ D, -1 * np.ones(N), np.inf * np.ones(N)
+ )
+
+ epsilon_solve = scipy.optimize.minimize(
+ loss, np.zeros(L, dtype=D.dtype), constraints=positive_density_constraint
+ )
+
+ if not epsilon_solve.success:
+ warnings.warn("TMLE optimization did not converge.", RuntimeWarning)
+
+ # Convert epsilon back to torch. This makes us happy... :)
+ packed_epsilon = torch.tensor(epsilon_solve.x, dtype=packed_influence.dtype)
+
+ return packed_epsilon
+
+
+# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn
+def tmle(
+ functional: Functional[P, S],
+    test_point: Point[T],
+ learning_rate: float = 1e-5,
+ n_grad_steps: int = 100,
+ n_tmle_steps: int = 1,
+ num_nmc_samples: int = 1000,
+ num_grad_samples: int = 1000,
+ log_jitter: float = 1e-6,
+ verbose: bool = False,
+ influence_estimator: Callable[
+ [Functional[P, S], Point[T]], Functional[P, S]
+ ] = influence_fn,
+ **influence_kwargs,
+) -> Functional[P, S]:
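+    """
+    Returns a functional that computes a targeted maximum likelihood estimate
+    (TMLE) of ``functional``. Each TMLE step (i) solves for the fluctuation
+    parameter ``epsilon`` that reweights the fitted model's density by
+    ``1 + influence @ epsilon`` at ``test_point`` and (ii) projects the
+    reweighted density back onto the model class by gradient descent, before
+    finally evaluating ``functional`` at the updated model.
+    """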
+ from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood
+
+ def _solve_epsilon(prev_model: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+        # find the epsilon that maximizes the corrected log-likelihood
+        # (1/N) * sum_i log(1 + D_i @ epsilon) at the test data
+
+ influence_at_test = influence_estimator(
+ functional, test_point, **influence_kwargs
+ )(prev_model)(*args, **kwargs)
+
+ flat_influence_at_test, _ = torch.utils._pytree.tree_flatten(influence_at_test)
+
+ N = flat_influence_at_test[0].shape[0]
+
+        # pack the pytree of influence values into a single (N, L) matrix
+        packed_influence_at_test = torch.concatenate(
+            [i.reshape(N, -1) for i in flat_influence_at_test], dim=-1
+        )
+
+ packed_epsilon = tmle_scipy_optimize_wrapper(packed_influence_at_test)
+
+ return packed_epsilon
+
+ def _solve_model_projection(
+ packed_epsilon: torch.Tensor,
+ prev_model: torch.nn.Module,
+ *args,
+ **kwargs,
+ ) -> torch.nn.Module:
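+        # Project the epsilon-corrected density back onto the model class:
+        # (1) sample synthetic data from the current model, (2) compute the
+        # epsilon-corrected log-likelihood at those samples, and (3) fit new
+        # parameters by minimizing the squared error between the model's
+        # log-likelihood and the corrected target.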
+ prev_params, functional_model = make_functional_call(
+ PredictiveFunctional(prev_model, num_samples=num_grad_samples)
+ )
+ prev_params = {k: v.detach() for k, v in prev_params.items()}
+
+ # Sample data from the model. Note that we only sample once during projection.
+ data = {
+ k: v
+ for k, v in functional_model(prev_params, *args, **kwargs).items()
+ if k in test_point
+ }
+
+ batched_log_prob: torch.nn.Module = BatchedNMCLogMarginalLikelihood(
+ prev_model, num_samples=num_nmc_samples
+ )
+
+ _, log_p_phi = make_functional_call(batched_log_prob)
+
+ influence_at_data = influence_estimator(functional, data, **influence_kwargs)(
+ prev_model
+ )(*args, **kwargs)
+ flat_influence_at_data, _ = torch.utils._pytree.tree_flatten(influence_at_data)
+ N_x = flat_influence_at_data[0].shape[0]
+
+        packed_influence_at_data = torch.concatenate(
+            [i.reshape(N_x, -1) for i in flat_influence_at_data], dim=-1
+        ).detach()
+
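+        # Corrected log-likelihood term log(1 + D @ epsilon), clamped below at
+        # log_jitter for numerical stability (matching the epsilon solve above).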
+ log_likelihood_correction = torch.log(
+ torch.maximum(
+ 1 + packed_influence_at_data.mv(packed_epsilon),
+ torch.tensor(log_jitter),
+ )
+ ).detach()
+ if verbose:
+ influence_at_test = influence_estimator(
+ functional, test_point, **influence_kwargs
+ )(prev_model)(*args, **kwargs)
+ flat_influence_at_test, _ = torch.utils._pytree.tree_flatten(
+ influence_at_test
+ )
+ N = flat_influence_at_test[0].shape[0]
+
+            packed_influence_at_test = torch.concatenate(
+                [i.reshape(N, -1) for i in flat_influence_at_test], dim=-1
+            ).detach()
+
+ log_likelihood_correction_at_test = torch.log(
+ torch.maximum(
+ 1 + packed_influence_at_test.mv(packed_epsilon),
+ torch.tensor(log_jitter),
+ )
+ )
+
+ print("previous log prob at test", log_p_phi(prev_params, test_point).sum())
+ print(
+ "new log prob at test",
+ (
+ log_p_phi(prev_params, test_point)
+ + log_likelihood_correction_at_test
+ ).sum(),
+ )
+
+ log_p_epsilon_at_data = (
+ log_likelihood_correction + log_p_phi(prev_params, data)
+ ).detach()
+
+ def loss(new_params):
+ log_p_phi_at_data = log_p_phi(new_params, data)
+ return torch.sum((log_p_phi_at_data - log_p_epsilon_at_data) ** 2)
+
+ grad_fn = torch.func.grad(loss)
+
+ new_params = {
+ k: v.clone().detach().requires_grad_(True) for k, v in prev_params.items()
+ }
+
+ optimizer = torchopt.adam(lr=learning_rate)
+
+ optimizer_state = optimizer.init(new_params)
+
+ for i in range(n_grad_steps):
+ grad = grad_fn(new_params)
+ if verbose and i % 100 == 0:
+ print(f"inner_iteration_{i}_loss", loss(new_params))
+ for parameter_name, parameter in prev_model.named_parameters():
+ parameter.data = new_params[f"model.{parameter_name}"]
+
+ estimate = functional(prev_model)(*args, **kwargs)
+ assert isinstance(estimate, torch.Tensor)
+ print(
+ f"inner_iteration_{i}_estimate",
+ estimate.detach().item(),
+ )
+ updates, optimizer_state = optimizer.update(
+ grad, optimizer_state, inplace=False
+ )
+ new_params = torchopt.apply_updates(new_params, updates)
+
+ for parameter_name, parameter in prev_model.named_parameters():
+ parameter.data = new_params[f"model.{parameter_name}"]
+
+ return prev_model
+
+ def _corrected_functional(*models: Callable[P, Any]) -> Callable[P, S]:
+ assert len(models) == 1
+ model = models[0]
+
+ assert isinstance(model, torch.nn.Module)
+
+ def _estimator(*args, **kwargs) -> S:
+ tmle_model = copy.deepcopy(model)
+
+ for _ in range(n_tmle_steps):
+ packed_epsilon = _solve_epsilon(tmle_model, *args, **kwargs)
+
+ tmle_model = _solve_model_projection(
+ packed_epsilon, tmle_model, *args, **kwargs
+ )
+ return functional(tmle_model)(*args, **kwargs)
+
+ return _estimator
+
+ return _corrected_functional
+
+
+# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn
def one_step_corrected_estimator(
functional: Functional[P, S],
*test_points: Point[T],
+ influence_estimator: Callable[
+ [Functional[P, S], Point[T]], Functional[P, S]
+ ] = influence_fn,
**influence_kwargs,
) -> Functional[P, S]:
"""
@@ -30,7 +242,7 @@ def one_step_corrected_estimator(
"""
influence_kwargs_one_step = influence_kwargs.copy()
influence_kwargs_one_step["pointwise_influence"] = False
- eif_fn = influence_fn(functional, *test_points, **influence_kwargs_one_step)
+ eif_fn = influence_estimator(functional, *test_points, **influence_kwargs_one_step)
def _corrected_functional(*model: Callable[P, Any]) -> Callable[P, S]:
plug_in_estimator = functional(*model)
diff --git a/docs/source/tmle.ipynb b/docs/source/tmle.ipynb
new file mode 100644
index 000000000..b9fe30e49
--- /dev/null
+++ b/docs/source/tmle.ipynb
@@ -0,0 +1,930 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Automated doubly robust estimation with ChiRho - TMLE Version"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Outline\n",
+ "\n",
+ "- [Setup](#setup)\n",
+ "\n",
+ "- [Overview: Systematically adjusting for observed confounding](#overview:-systematically-adjusting-for-observed-confounding)\n",
+ " - [Task: Treatment effect estimation with observational data](#task:-treatment-effect-estimation-with-observational-data)\n",
+ " - [Challenge: Confounding](#challenge:-confounding)\n",
+ " - [Assumptions: All confounders observed](#assumptions:-all-confounders-observed)\n",
+ " - [Intuition: Statistically adjusting for confounding](#intuition:-statistically-adjusting-for-confounding)\n",
+ "\n",
+ "- [Causal Probabilistic Program](#causal-probabilistic-program)\n",
+ " - [Model description](#model-description)\n",
+ " - [Generating data](#generating-data)\n",
+ " - [Fit parameters via maximum likelihood](#fit-parameters-via-maximum-likelihood)\n",
+ "\n",
+ "- [Causal Query: average treatment effect (ATE)](#causal-query:-average-treatment-effect-\\(ATE\\))\n",
+ " - [Defining the target functional](#defining-the-target-functional)\n",
+ " - [Closed form doubly robust correction](#closed-form-doubly-robust-correction)\n",
+ " - [Computing automated doubly robust correction via Monte Carlo](#computing-automated-doubly-robust-correction-via-monte-carlo)\n",
+ " - [Results](#results)\n",
+ "\n",
+ "- [References](#references)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here, we install the necessary Pytorch, Pyro, and ChiRho dependencies for this example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "NOTE: Redirects are currently not supported in Windows or MacOs.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from typing import Callable, Optional, Tuple\n",
+ "\n",
+ "import functools\n",
+ "import torch\n",
+ "import math\n",
+ "import seaborn as sns\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import pyro\n",
+ "import pyro.distributions as dist\n",
+ "from pyro.infer import Predictive\n",
+ "import pyro.contrib.gp as gp\n",
+ "\n",
+ "from chirho.counterfactual.handlers import MultiWorldCounterfactual\n",
+ "from chirho.indexed.ops import IndexSet, gather\n",
+ "from chirho.interventional.handlers import do\n",
+ "from chirho.robust.internals.utils import ParamDict\n",
+ "from chirho.robust.handlers.estimators import one_step_corrected_estimator, tmle\n",
+ "from chirho.robust.handlers.predictive import PredictiveModel \n",
+ "from chirho.robust.ops import influence_fn\n",
+ "\n",
+ "pyro.settings.set(module_local_params=True)\n",
+ "\n",
+ "sns.set_style(\"white\")\n",
+ "\n",
+ "pyro.set_rng_seed(321) # for reproducibility"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Overview\n",
+ "\n",
+ "In this tutorial, we will use ChiRho to estimate the average treatment effect (ATE) from observational data. We will use a simple example to illustrate the basic concepts of doubly robust estimation and how ChiRho can be used to automate the process for more general summaries of interest. \n",
+ "\n",
+ "There are five main steps to our doubly robust estimation procedure but only the last step is different from a standard probabilistic programming workflow:\n",
+ "1. Write model of interest\n",
+ " - Define probabilistic model of interest using Pyro\n",
+ "2. Feed in data\n",
+ " - Observed data used to train the model\n",
+ "3. Run inference\n",
+ " - Use Pyro's rich inference library to fit the model to the data\n",
+ "4. Define target functional\n",
+ " - This is the model summary of interest (e.g. average treatment effect)\n",
+ "5. Compute robust estimate\n",
+ " - Use ChiRho to compute the doubly robust estimate of the target functional\n",
+ " - Importantly, this step is automated and does not require refitting the model for each new functional\n",
+ "\n",
+ "\n",
+ "Our proposed automated robust inference pipeline is summarized in the figure below.\n",
+ "\n",
+ "![fig1](figures/robust_pipeline.png)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Causal Probabilistic Program\n",
+ "\n",
+ "### Model Description\n",
+ "In this example, we will focus on a cannonical model `CausalGLM` consisting of three types of variables: binary treatment (`A`), confounders (`X`), and response (`Y`). For simplicitly, we assume that the response is generated from a generalized linear model with link function $g$. The model is described by the following generative process:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align*}\n",
+ "X &\\sim \\text{Normal}(0, I_p) \\\\\n",
+ "A &\\sim \\text{Bernoulli}(\\pi(X)) \\\\\n",
+ "\\mu &= \\beta_0 + \\beta_1^T X + \\tau A \\\\\n",
+ "Y &\\sim \\text{ExponentialFamily}(\\text{mean} = g^{-1}(\\mu))\n",
+ "\\end{align*}\n",
+ "$$\n",
+ "\n",
+ "where $p$ denotes the number of confounders, $\\pi(X)$ is the probability of treatment conditional on confounders $X$, $\\beta_0$ is the intercept, $\\beta_1$ is the confounder effect, and $\\tau$ is the treatment effect."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CausalGLM(pyro.nn.PyroModule):\n",
+ " def __init__(\n",
+ " self,\n",
+ " p: int,\n",
+ " N: int,\n",
+ " link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n",
+ " include_prior: bool = True,\n",
+ " prior_scale: Optional[float] = None,\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.p = p\n",
+ " self.N = N\n",
+ " self.link_fn = link_fn\n",
+ " self.include_prior = include_prior\n",
+ " if prior_scale is None:\n",
+ " self.prior_scale = 1 / math.sqrt(self.p)\n",
+ " else:\n",
+ " self.prior_scale = prior_scale\n",
+ "\n",
+ " def sample_outcome_weights(self):\n",
+ " return pyro.sample(\n",
+ " \"outcome_weights\",\n",
+ " dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1),\n",
+ " )\n",
+ "\n",
+ " def sample_intercept(self):\n",
+ " return pyro.sample(\"intercept\", dist.Normal(0.0, 1.0))\n",
+ "\n",
+ " def sample_propensity_weights(self):\n",
+ " return pyro.sample(\n",
+ " \"propensity_weights\",\n",
+ " dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1),\n",
+ " )\n",
+ "\n",
+ " def sample_treatment_weight(self):\n",
+ " return pyro.sample(\"treatment_weight\", dist.Normal(0.0, 1.0))\n",
+ "\n",
+ " def sample_covariate_loc_scale(self):\n",
+ " return torch.zeros(self.p), torch.ones(self.p)\n",
+ " \n",
+ " def generate_datum(self, x_loc, x_scale, propensity_weights, outcome_weights, tau, intercept):\n",
+ " X = pyro.sample(\"X\", dist.Normal(x_loc, x_scale).to_event(1))\n",
+ " A = pyro.sample(\n",
+ " \"A\",\n",
+ " dist.Bernoulli(\n",
+ " logits=torch.einsum(\"...i,...i->...\", X, propensity_weights)\n",
+ " ),\n",
+ " )\n",
+ " return pyro.sample(\n",
+ " \"Y\",\n",
+ " self.link_fn(\n",
+ " torch.einsum(\"...i,...i->...\", X, outcome_weights) + A * tau + intercept\n",
+ " ),\n",
+ " )\n",
+ "\n",
+ " def forward(self):\n",
+ " with pyro.poutine.mask(mask=self.include_prior):\n",
+ " intercept = self.sample_intercept()\n",
+ " outcome_weights = self.sample_outcome_weights()\n",
+ " propensity_weights = self.sample_propensity_weights()\n",
+ " tau = self.sample_treatment_weight()\n",
+ " x_loc, x_scale = self.sample_covariate_loc_scale()\n",
+ " with pyro.plate(\"plate\", self.N, dim=-1):\n",
+ " return self.generate_datum(x_loc, x_scale, propensity_weights, outcome_weights, tau, intercept)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we will condition on both treatment and confounders to estimate the causal effect of treatment on the outcome. We will use the following causal probabilistic program to do so:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ConditionedCausalGLM(CausalGLM):\n",
+ " def __init__(\n",
+ " self,\n",
+ " X: torch.Tensor,\n",
+ " A: torch.Tensor,\n",
+ " Y: torch.Tensor,\n",
+ " link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n",
+ " include_prior: bool = True,\n",
+ " prior_scale: Optional[float] = None,\n",
+ " ):\n",
+ " p = X.shape[1]\n",
+ " N = X.shape[0]\n",
+ " super().__init__(p, N, link_fn, include_prior, prior_scale)\n",
+ " self.X = X\n",
+ " self.A = A\n",
+ " self.Y = Y\n",
+ "\n",
+ " def forward(self):\n",
+ " with pyro.condition(data={\"X\": self.X, \"A\": self.A, \"Y\": self.Y}):\n",
+ " return super().forward()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Visualize the model\n",
+ "pyro.render_model(\n",
+ " ConditionedCausalGLM(torch.zeros(1, 1), torch.zeros(1), torch.zeros(1)),\n",
+ " render_params=True, \n",
+ " render_distributions=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Generating data\n",
+ "\n",
+ "For evaluation, we generate `N_datasets` datasets, each with `N` samples. We compare vanilla estimates of the target functional with the double robust estimates of the target functional across the `N_sims` datasets. We use a similar data generating process as in Kennedy (2022)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class GroundTruthModel(CausalGLM):\n",
+ " def __init__(\n",
+ " self,\n",
+ " p: int,\n",
+ " N: int,\n",
+ " alpha: int,\n",
+ " beta: int,\n",
+ " link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n",
+ " treatment_weight: float = 0.0,\n",
+ " ):\n",
+ " super().__init__(p, N, link_fn)\n",
+ " self.alpha = alpha # sparsity of propensity weights\n",
+ " self.beta = beta # sparsity of outcome weights\n",
+ " self.treatment_weight = treatment_weight\n",
+ "\n",
+ " def sample_outcome_weights(self):\n",
+ " outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p)\n",
+ " outcome_weights[self.beta :] = 0.0\n",
+ " return outcome_weights\n",
+ "\n",
+ " def sample_propensity_weights(self):\n",
+ " propensity_weights = 1 / math.sqrt(self.alpha) * torch.ones(self.p)\n",
+ " propensity_weights[self.alpha :] = 0.0\n",
+ " return propensity_weights\n",
+ "\n",
+ " def sample_treatment_weight(self):\n",
+ " return torch.tensor(self.treatment_weight)\n",
+ "\n",
+ " def sample_intercept(self):\n",
+ " return torch.tensor(0.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_datasets = 50\n",
+ "simulated_datasets = []\n",
+ "\n",
+ "# Data configuration\n",
+ "p = 200\n",
+ "alpha = 50\n",
+ "beta = 50\n",
+ "N_train = 500\n",
+ "N_test = 500\n",
+ "treatment_weight = 1.0\n",
+ "\n",
+ "true_model = GroundTruthModel(p, N_train+N_test, alpha, beta, treatment_weight=treatment_weight)\n",
+ "prior_model = CausalGLM(p, N_train+N_test)\n",
+ "\n",
+ "# Generate data\n",
+ "D = Predictive(true_model, num_samples=N_datasets, return_sites=[\"X\", \"A\", \"Y\"], parallel=True)()\n",
+ "D_train = {k: v[:, :N_train] for k, v in D.items()}\n",
+ "D_test = {k: v[:, N_train:] for k, v in D.items()}\n",
+ "\n",
+ "# D_train : (N_datasets, N_train, p)\n",
+ "# D_test : (N_datasets, N_test, p)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Fit parameters via maximum likelihood"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0\n",
+ "tensor(143220.2969, grad_fn=)\n",
+ "tensor(142401.8281, grad_fn=)\n",
+ "tensor(142398.5000, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.5000, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.5312, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.5000, grad_fn=)\n",
+ "tensor(142398.6094, grad_fn=)\n",
+ "tensor(142398.4844, grad_fn=)\n",
+ "tensor(142398.6094, grad_fn=)\n",
+ "tensor(142398.6562, grad_fn=)\n",
+ "1\n",
+ "tensor(142864.0625, grad_fn=)\n",
+ "tensor(142156.9531, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.7969, grad_fn=)\n",
+ "tensor(142155.7812, grad_fn=)\n",
+ "tensor(142155.8281, grad_fn=)\n",
+ "tensor(142155.8125, grad_fn=)\n",
+ "tensor(142155.8906, grad_fn=)\n",
+ "tensor(142155.9688, grad_fn=)\n",
+ "tensor(142155.7969, grad_fn=)\n",
+ "2\n",
+ "tensor(143104.5469, grad_fn=)\n",
+ "tensor(142385.0469, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.3750, grad_fn=)\n",
+ "tensor(142384.3906, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.5156, grad_fn=)\n",
+ "tensor(142384.3750, grad_fn=)\n",
+ "tensor(142384.3906, grad_fn=)\n",
+ "tensor(142384.3594, grad_fn=)\n",
+ "tensor(142384.4062, grad_fn=)\n",
+ "tensor(142384.4375, grad_fn=)\n",
+ "tensor(142384.4062, grad_fn=)\n",
+ "tensor(142384.3906, grad_fn=)\n",
+ "3\n",
+ "tensor(142825.3125, grad_fn=)\n",
+ "tensor(142005.4062, grad_fn=)\n",
+ "tensor(142000.4844, grad_fn=)\n",
+ "tensor(142000.4688, grad_fn=)\n",
+ "tensor(142000.4688, grad_fn=)\n",
+ "tensor(142000.4688, grad_fn=)\n",
+ "tensor(142000.4688, grad_fn=)\n",
+ "tensor(142000.4688, grad_fn=)\n",
+ "tensor(142000.5000, grad_fn=)\n",
+ "tensor(142000.4844, grad_fn=)\n",
+ "tensor(142000.4844, grad_fn=)\n",
+ "tensor(142000.5000, grad_fn=)\n",
+ "tensor(142000.5156, grad_fn=)\n",
+ "tensor(142000.4844, grad_fn=)\n",
+ "tensor(142000.5156, grad_fn=)\n",
+ "tensor(142000.4844, grad_fn=)\n",
+ "tensor(142000.7188, grad_fn=)\n",
+ "tensor(142000.5938, grad_fn=)\n",
+ "tensor(142000.6094, grad_fn=)\n",
+ "tensor(142000.5000, grad_fn=)\n",
+ "4\n",
+ "tensor(143287.3281, grad_fn=)\n",
+ "tensor(142508.1094, grad_fn=)\n",
+ "tensor(142506.4531, grad_fn=)\n"
+ ]
+    }
+ ],
+ "source": [
+ "fitted_params = []\n",
+ "for i in range(N_datasets):\n",
+ " print(i)\n",
+ "\n",
+ " # Fit model using maximum likelihood\n",
+ " conditioned_model = ConditionedCausalGLM(\n",
+ " X=D_train[\"X\"][i], A=D_train[\"A\"][i], Y=D_train[\"Y\"][i]\n",
+ " )\n",
+ " \n",
+ " guide_train = pyro.infer.autoguide.AutoDelta(conditioned_model)\n",
+ " elbo = pyro.infer.Trace_ELBO()(conditioned_model, guide_train)\n",
+ "\n",
+ " # initialize parameters\n",
+ " elbo()\n",
+ " adam = torch.optim.Adam(elbo.parameters(), lr=0.03)\n",
+ "\n",
+ " # Do gradient steps\n",
+ " for j in range(2000):\n",
+ " adam.zero_grad()\n",
+ " loss = elbo()\n",
+ " if j % 100 == 0:\n",
+ " print(loss)\n",
+ " loss.backward()\n",
+ " adam.step()\n",
+ "\n",
+ " theta_hat = {\n",
+ " k: v.clone().detach().requires_grad_(True) for k, v in guide_train().items()\n",
+ " }\n",
+ " fitted_params.append(theta_hat)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Causal Query: Sample Average treatment effect (ATE)\n",
+ "\n",
+ "The average treatment effect summarizes, on average, how much the treatment changes the response, $ATE = \\mathbb{E}[Y|do(A=1)] - \\mathbb{E}[Y|do(A=0)]$. The `do` notation indicates that the expectations are taken according to *intervened* versions of the model, with $A$ set to a particular value. Note from our [tutorial](tutorial_i.ipynb) that this is different from conditioning on $A$ in the original `causal_model`, which assumes $X$ and $T$ are dependent.\n",
+ "\n",
+ "\n",
+ "To implement this query in ChiRho, we define the `SATEFunctional` class which take in a `model` and `guide` and returns the average treatment effect by simulating from the posterior predictive distribution of the model and guide."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Defining the target functional"
+ ]
+ },
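+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Concretely, the functional below computes the sample analogue of the ATE,\n",
+    "\n",
+    "$$\n",
+    "SATE = \\frac{1}{N} \\sum_{i=1}^{N} \\big( Y_i(1) - Y_i(0) \\big),\n",
+    "$$\n",
+    "\n",
+    "where $Y_i(a)$ denotes unit $i$'s outcome under the intervention $do(A=a)$. It does so by intervening on `A` in two parallel counterfactual worlds, gathering the two outcomes `Y0` and `Y1` with `gather`, and averaging their difference over the units in the model's plate.\n"
+   ]
+  },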
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class SATEFunctional(torch.nn.Module):\n",
+ " def __init__(self, model: Callable, *, num_monte_carlo: int = 100):\n",
+ " super().__init__()\n",
+ " self.model = model\n",
+ " self.num_monte_carlo = num_monte_carlo\n",
+ " \n",
+ " def forward(self, *args, **kwargs):\n",
+ " with MultiWorldCounterfactual():\n",
+ " with do(actions=dict(A=(torch.tensor(0.0), torch.tensor(1.0)))):\n",
+ " Ys = self.model(*args, **kwargs)\n",
+ " Y0 = gather(Ys, IndexSet(A={1}), event_dim=0)\n",
+ " Y1 = gather(Ys, IndexSet(A={2}), event_dim=0)\n",
+ " sate = (Y1 - Y0).mean(dim=-1, keepdim=True).squeeze()\n",
+ " return pyro.deterministic(\"SATE\", sate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SATE = SATEFunctional(true_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Closed form doubly robust correction\n",
+ "\n",
+ "For the average treatment effect functional, there exists a closed-form analytical formula for the doubly robust correction. This formula is derived in Kennedy (2022) and is implemented below:"
+ ]
+ },
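+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, the term implemented below is the inverse-propensity-weighted residual component of the ATE's efficient influence function, evaluated at a point $(X, A, Y)$:\n",
+    "\n",
+    "$$\n",
+    "\\varphi(X, A, Y) = \\left( \\frac{A}{\\pi(X)} - \\frac{1 - A}{1 - \\pi(X)} \\right) \\big( Y - \\mu(X, A) \\big),\n",
+    "$$\n",
+    "\n",
+    "where $\\pi(X)$ is the fitted propensity score and $\\mu(X, A)$ is the fitted outcome regression. Under the linear `CausalGLM` model, $\\mu(X, 1) - \\mu(X, 0) = \\tau$ exactly, so the remaining terms of the full influence function, $\\mu(X, 1) - \\mu(X, 0) - \\tau$, vanish at the fitted parameters.\n"
+   ]
+  },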
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import Any\n",
+ "from chirho.robust.ops import Functional, Point, P, S, T\n",
+ "\n",
+ "def SATECausalGLM_analytic_influence(functional: Functional[P, S], \n",
+ " point: Point[T], \n",
+ " pointwise_influence: bool = True,\n",
+ " **kwargs) -> Functional[P, S]:\n",
+ " # assert isinstance(functional, SATEFunctional)\n",
+ " def new_functional(model: Callable[P, Any]) -> Callable[P, S]:\n",
+ " assert isinstance(model.model, CausalGLM)\n",
+ " theta = dict(model.guide.named_parameters())\n",
+ " def new_model(*args, **kwargs):\n",
+ " X = point[\"X\"]\n",
+ " A = point[\"A\"]\n",
+ " Y = point[\"Y\"]\n",
+ " \n",
+ " pi_X = torch.sigmoid(torch.einsum(\"...i,...i->...\", X, theta[\"propensity_weights_param\"]))\n",
+ " mu_X = (\n",
+ " torch.einsum(\"...i,...i->...\", X, theta[\"outcome_weights_param\"])\n",
+ " + A * theta[\"treatment_weight_param\"]\n",
+ " + theta[\"intercept_param\"]\n",
+ " )\n",
+ " analytic_eif_at_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)\n",
+ " if pointwise_influence:\n",
+ " return analytic_eif_at_pts\n",
+ " else:\n",
+ " return analytic_eif_at_pts.mean()\n",
+ "\n",
+ "\n",
+ " return new_model\n",
+ " return new_functional\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# # Closed form expression\n",
+ "# def closed_form_doubly_robust_ate_correction(X_test, theta) -> Tuple[torch.Tensor, torch.Tensor]:\n",
+ "# X = X_test[\"X\"]\n",
+ "# A = X_test[\"A\"]\n",
+ "# Y = X_test[\"Y\"]\n",
+ "# pi_X = torch.sigmoid(X.mv(theta[\"propensity_weights\"]))\n",
+ "# mu_X = (\n",
+ "# X.mv(theta[\"outcome_weights\"])\n",
+ "# + A * theta[\"treatment_weight\"]\n",
+ "# + theta[\"intercept\"]\n",
+ "# )\n",
+ "# analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)\n",
+ "# analytic_correction = analytic_eif_at_test_pts.mean()\n",
+ "# return analytic_correction, analytic_eif_at_test_pts"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Computing automated doubly robust correction via Monte Carlo\n",
+ "\n",
+ "While the doubly robust correction term is known in closed-form for the average treatment effect functional, our `one_step_correction` function in `ChiRho` works for a wide class of other functionals. We focus on the average treatment effect functional here so that we have a ground truth to compare `one_step_correction` against."
+ ]
+ },
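+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a rough sketch of the two automated estimators: the one-step estimator adds the estimated influence function, averaged over the test points $Z_1, \\dots, Z_n$, to the plug-in estimate,\n",
+    "\n",
+    "$$\n",
+    "\\hat{\\psi}_{\\text{one-step}} = \\psi(\\hat{P}) + \\frac{1}{n} \\sum_{i=1}^{n} \\hat{\\varphi}(Z_i),\n",
+    "$$\n",
+    "\n",
+    "while `tmle` instead uses the influence function to fluctuate the fitted distribution itself: it reweights the fitted density by $1 + \\hat{\\varphi}(Z_i)^\\top \\epsilon$ with $\\epsilon$ chosen by maximum likelihood on the test points, projects the reweighted density back onto the model class, and then evaluates the plug-in functional on the updated model.\n"
+   ]
+  },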
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tracemalloc\n",
+ "\n",
+ "tracemalloc.start()\n",
+ "\n",
+ "# Helper class to create a trivial guide that returns the maximum likelihood estimate\n",
+ "class MLEGuide(torch.nn.Module):\n",
+ " def __init__(self, mle_est: ParamDict):\n",
+ " super().__init__()\n",
+ " self.names = list(mle_est.keys())\n",
+ " for name, value in mle_est.items():\n",
+ " setattr(self, name + \"_param\", torch.nn.Parameter(value))\n",
+ "\n",
+ " def forward(self, *args, **kwargs):\n",
+ " for name in self.names:\n",
+ " value = getattr(self, name + \"_param\")\n",
+ " pyro.sample(\n",
+ " name, pyro.distributions.Delta(value, event_dim=len(value.shape))\n",
+ " )\n",
+ "\n",
+ "# Compute doubly robust ATE estimates using both the automated and closed form expressions\n",
+ "# estimators = {\"tmle\": tmle, \"one_step\": one_step_corrected_estimator}\n",
+ "estimators = {\"one_step\": one_step_corrected_estimator}\n",
+ "estimator_kwargs = {\"tmle\": {\"learning_rate\":5e-5,\n",
+ " \"n_grad_steps\":500,\n",
+ " \"n_tmle_steps\":1,\n",
+ " \"num_nmc_samples\":1000,\n",
+ " \"num_grad_samples\":N_test}, \"one_step\": {}}\n",
+ "# influences = {\"analytic\": SATECausalGLM_analytic_influence, \"monte_carlo\": influence_fn}\n",
+ "influences = {\"analytic\": SATECausalGLM_analytic_influence}\n",
+ "\n",
+ "estimates = {f\"{influence}-{estimator}\": torch.zeros(N_datasets) for influence in influences.keys() for estimator in estimators.keys()}\n",
+ "estimates[\"plug-in-mle\"] = torch.zeros(N_datasets)\n",
+ "estimates[\"plug-in-prior\"] = torch.zeros(N_datasets)\n",
+ "estimates[\"plug-in-truth\"] = torch.zeros(N_datasets)\n",
+ "\n",
+ "functional = functools.partial(SATEFunctional, num_monte_carlo=10000)\n",
+ "\n",
+ "for i in range(N_datasets):\n",
+ " print(\"plug-in-prior\", i)\n",
+ " estimates[\"plug-in-prior\"][i] = functional(prior_model)().item()\n",
+ "\n",
+ " print(\"plug-in-truth\", i)\n",
+ " estimates[\"plug-in-truth\"][i] = functional(true_model)().item()\n",
+ "\n",
+ " # D_test = simulated_datasets[i][1]\n",
+ " theta_hat = fitted_params[i]\n",
+ " mle_guide = MLEGuide(theta_hat)\n",
+ "\n",
+ " model = PredictiveModel(CausalGLM(p, N_test), mle_guide)\n",
+ " \n",
+ " print(\"plug-in-mle\", i)\n",
+ " estimates[\"plug-in-mle\"][i] = functional(model)().detach().item()\n",
+ "\n",
+ " for estimator_str, estimator in estimators.items():\n",
+ " for influence_str, influence in influences.items():\n",
+ " if estimator_str == \"tmle\" and influence_str == \"monte_carlo\":\n",
+ " continue\n",
+ "\n",
+ " print(estimator_str, influence_str, i)\n",
+ " \n",
+ " estimate = estimator(\n",
+ " functional, \n",
+ " {\"X\": D_test[\"X\"][i], \"A\": D_test[\"A\"][i], \"Y\": D_test[\"Y\"][i]},\n",
+ " num_samples_outer=max(10000, 100 * p), \n",
+ " num_samples_inner=1,\n",
+ " influence_estimator=influence,\n",
+ " **estimator_kwargs[estimator_str]\n",
+ " )(model)()\n",
+ "\n",
+ " estimates[f\"{influence_str}-{estimator_str}\"][i] = estimate.detach().item()\n",
+ "\n",
+ " print(tracemalloc.get_traced_memory())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "results = pd.DataFrame(estimates)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The true treatment effect is 0, so a mean estimate closer to zero is better\n",
+ "results.describe().round(2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize the results\n",
+ "fig, ax = plt.subplots()\n",
+ "\n",
+ "for col in results.columns:\n",
+ " sns.kdeplot(results[col], ax=ax, label=col)\n",
+ "\n",
+ "ax.axvline(treatment_weight, color=\"black\", label=\"True ATE\", linestyle=\"--\")\n",
+ "ax.set_yticks([])\n",
+ "sns.despine()\n",
+ "ax.legend(loc=\"upper right\")\n",
+ "ax.set_xlabel(\"ATE Estimate\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plt.scatter(\n",
+ "# results['automated_monte_carlo_correction'],\n",
+ "# results['analytic_correction'],\n",
+ "# )\n",
+ "# plt.plot(np.linspace(-.2, .5), np.linspace(-.2, .5), color=\"black\", linestyle=\"dashed\")\n",
+ "# plt.xlabel(\"DR-Monte Carlo\")\n",
+ "# plt.ylabel(\"DR-Analytic\")\n",
+ "# sns.despine()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## References\n",
+ "\n",
+ "Kennedy, Edward. \"Towards optimal doubly robust estimation of heterogeneous causal effects\", 2022. https://arxiv.org/abs/2004.14497."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "basis",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/setup.py b/setup.py
index 47c6dcd3d..0b425f72a 100644
--- a/setup.py
+++ b/setup.py
@@ -25,6 +25,7 @@
]
DYNAMICAL_REQUIRE = ["torchdiffeq"]
+ROBUST_REQUIRE = ["torchopt"]
setup(
name="chirho",
@@ -45,8 +46,9 @@
],
extras_require={
"dynamical": DYNAMICAL_REQUIRE,
+ "robust": ROBUST_REQUIRE,
"extras": EXTRAS_REQUIRE,
- "test": EXTRAS_REQUIRE + DYNAMICAL_REQUIRE
+ "test": EXTRAS_REQUIRE + DYNAMICAL_REQUIRE + ROBUST_REQUIRE
+ [
"pytest",
"pytest-cov",
diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py
index e43015282..9cd55e749 100644
--- a/tests/robust/test_handlers.py
+++ b/tests/robust/test_handlers.py
@@ -1,3 +1,4 @@
+import copy
import functools
from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar
@@ -6,7 +7,7 @@
import torch
from typing_extensions import ParamSpec
-from chirho.robust.handlers.estimators import one_step_corrected_estimator
+from chirho.robust.handlers.estimators import one_step_corrected_estimator, tmle
from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel
from .robust_fixtures import SimpleGuide, SimpleModel
@@ -42,7 +43,7 @@
@pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)])
@pytest.mark.parametrize("cg_iters", [None, 1, 10])
@pytest.mark.parametrize("num_predictive_samples", [1, 5])
-@pytest.mark.parametrize("estimation_method", [one_step_corrected_estimator])
+@pytest.mark.parametrize("estimation_method", [one_step_corrected_estimator, tmle])
def test_estimator_smoke(
model,
guide,
@@ -66,6 +67,20 @@ def test_estimator_smoke(
)().items()
}
+ predictive_model = PredictiveModel(model, guide)
+
+ prev_params = copy.deepcopy(dict(predictive_model.named_parameters()))
+
+ if estimation_method == tmle:
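+        # Keep TMLE's inner optimization tiny so the smoke test stays fast.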
+ estimator_kwargs = {
+ "n_tmle_steps": 1,
+ "n_grad_steps": 2,
+ "num_nmc_samples": 10,
+ "num_grad_samples": 10,
+ }
+ else:
+ estimator_kwargs = {}
+
estimator = estimation_method(
functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
test_datum,
@@ -73,7 +88,8 @@ def test_estimator_smoke(
num_samples_outer=num_samples_outer,
num_samples_inner=num_samples_inner,
cg_iters=cg_iters,
- )(PredictiveModel(model, guide))
+ **estimator_kwargs,
+ )(predictive_model)
estimate_on_test: Mapping[str, torch.Tensor] = estimator()
assert len(estimate_on_test) > 0
@@ -83,3 +99,8 @@ def test_estimator_smoke(
assert not torch.isclose(
v, torch.zeros_like(v)
).all(), f"{estimation_method} estimator for {k} was zero"
+
+ # Assert estimator doesn't have side effects on model parameters.
+ new_params = dict(predictive_model.named_parameters())
+ for k, v in prev_params.items():
+ assert torch.allclose(v, new_params[k]), f"{k} was updated"