diff --git a/chirho/robust/handlers/estimators.py b/chirho/robust/handlers/estimators.py index eb6e8d6ee..779b108e8 100644 --- a/chirho/robust/handlers/estimators.py +++ b/chirho/robust/handlers/estimators.py @@ -1,8 +1,13 @@ +import copy +import warnings from typing import Any, Callable, TypeVar import torch +import torchopt from typing_extensions import ParamSpec +from chirho.robust.handlers.predictive import PredictiveFunctional +from chirho.robust.internals.utils import make_functional_call from chirho.robust.ops import Functional, Point, influence_fn P = ParamSpec("P") @@ -10,9 +15,216 @@ T = TypeVar("T") +def tmle_scipy_optimize_wrapper( + packed_influence, log_jitter: float = 1e-6 +) -> torch.Tensor: + import numpy as np + import scipy + from scipy.optimize import LinearConstraint + + # Turn things into numpy. This makes us sad... :( + D = packed_influence.detach().numpy() + + N, L = D.shape[0], D.shape[1] + + def loss(epsilon): + correction = 1 + D.dot(epsilon) + + return -np.sum(np.log(np.maximum(correction, log_jitter))) + + positive_density_constraint = LinearConstraint( + D, -1 * np.ones(N), np.inf * np.ones(N) + ) + + epsilon_solve = scipy.optimize.minimize( + loss, np.zeros(L, dtype=D.dtype), constraints=positive_density_constraint + ) + + if not epsilon_solve.success: + warnings.warn("TMLE optimization did not converge.", RuntimeWarning) + + # Convert epsilon back to torch. This makes us happy... :) + packed_epsilon = torch.tensor(epsilon_solve.x, dtype=packed_influence.dtype) + + return packed_epsilon + + +# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn +def tmle( + functional: Functional[P, S], + test_point: Point, + learning_rate: float = 1e-5, + n_grad_steps: int = 100, + n_tmle_steps: int = 1, + num_nmc_samples: int = 1000, + num_grad_samples: int = 1000, + log_jitter: float = 1e-6, + verbose: bool = False, + influence_estimator: Callable[ + [Functional[P, S], Point[T]], Functional[P, S] + ] = influence_fn, + **influence_kwargs, +) -> Functional[P, S]: + from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood + + def _solve_epsilon(prev_model: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + # find epsilon that minimizes the corrected density on test data + + influence_at_test = influence_estimator( + functional, test_point, **influence_kwargs + )(prev_model)(*args, **kwargs) + + flat_influence_at_test, _ = torch.utils._pytree.tree_flatten(influence_at_test) + + N = flat_influence_at_test[0].shape[0] + + packed_influence_at_test = torch.concatenate( + [i.reshape(N, -1) for i in flat_influence_at_test] + ) + + packed_epsilon = tmle_scipy_optimize_wrapper(packed_influence_at_test) + + return packed_epsilon + + def _solve_model_projection( + packed_epsilon: torch.Tensor, + prev_model: torch.nn.Module, + *args, + **kwargs, + ) -> torch.nn.Module: + prev_params, functional_model = make_functional_call( + PredictiveFunctional(prev_model, num_samples=num_grad_samples) + ) + prev_params = {k: v.detach() for k, v in prev_params.items()} + + # Sample data from the model. Note that we only sample once during projection. + data = { + k: v + for k, v in functional_model(prev_params, *args, **kwargs).items() + if k in test_point + } + + batched_log_prob: torch.nn.Module = BatchedNMCLogMarginalLikelihood( + prev_model, num_samples=num_nmc_samples + ) + + _, log_p_phi = make_functional_call(batched_log_prob) + + influence_at_data = influence_estimator(functional, data, **influence_kwargs)( + prev_model + )(*args, **kwargs) + flat_influence_at_data, _ = torch.utils._pytree.tree_flatten(influence_at_data) + N_x = flat_influence_at_data[0].shape[0] + + packed_influence_at_data = torch.concatenate( + [i.reshape(N_x, -1) for i in flat_influence_at_data] + ).detach() + + log_likelihood_correction = torch.log( + torch.maximum( + 1 + packed_influence_at_data.mv(packed_epsilon), + torch.tensor(log_jitter), + ) + ).detach() + if verbose: + influence_at_test = influence_estimator( + functional, test_point, **influence_kwargs + )(prev_model)(*args, **kwargs) + flat_influence_at_test, _ = torch.utils._pytree.tree_flatten( + influence_at_test + ) + N = flat_influence_at_test[0].shape[0] + + packed_influence_at_test = torch.concatenate( + [i.reshape(N, -1) for i in flat_influence_at_test] + ).detach() + + log_likelihood_correction_at_test = torch.log( + torch.maximum( + 1 + packed_influence_at_test.mv(packed_epsilon), + torch.tensor(log_jitter), + ) + ) + + print("previous log prob at test", log_p_phi(prev_params, test_point).sum()) + print( + "new log prob at test", + ( + log_p_phi(prev_params, test_point) + + log_likelihood_correction_at_test + ).sum(), + ) + + log_p_epsilon_at_data = ( + log_likelihood_correction + log_p_phi(prev_params, data) + ).detach() + + def loss(new_params): + log_p_phi_at_data = log_p_phi(new_params, data) + return torch.sum((log_p_phi_at_data - log_p_epsilon_at_data) ** 2) + + grad_fn = torch.func.grad(loss) + + new_params = { + k: v.clone().detach().requires_grad_(True) for k, v in prev_params.items() + } + + optimizer = torchopt.adam(lr=learning_rate) + + optimizer_state = optimizer.init(new_params) + + for i in range(n_grad_steps): + grad = grad_fn(new_params) + if verbose and i % 100 == 0: + print(f"inner_iteration_{i}_loss", loss(new_params)) + for parameter_name, parameter in prev_model.named_parameters(): + parameter.data = new_params[f"model.{parameter_name}"] + + estimate = functional(prev_model)(*args, **kwargs) + assert isinstance(estimate, torch.Tensor) + print( + f"inner_iteration_{i}_estimate", + estimate.detach().item(), + ) + updates, optimizer_state = optimizer.update( + grad, optimizer_state, inplace=False + ) + new_params = torchopt.apply_updates(new_params, updates) + + for parameter_name, parameter in prev_model.named_parameters(): + parameter.data = new_params[f"model.{parameter_name}"] + + return prev_model + + def _corrected_functional(*models: Callable[P, Any]) -> Callable[P, S]: + assert len(models) == 1 + model = models[0] + + assert isinstance(model, torch.nn.Module) + + def _estimator(*args, **kwargs) -> S: + tmle_model = copy.deepcopy(model) + + for _ in range(n_tmle_steps): + packed_epsilon = _solve_epsilon(tmle_model, *args, **kwargs) + + tmle_model = _solve_model_projection( + packed_epsilon, tmle_model, *args, **kwargs + ) + return functional(tmle_model)(*args, **kwargs) + + return _estimator + + return _corrected_functional + + +# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn def one_step_corrected_estimator( functional: Functional[P, S], *test_points: Point[T], + influence_estimator: Callable[ + [Functional[P, S], Point[T]], Functional[P, S] + ] = influence_fn, **influence_kwargs, ) -> Functional[P, S]: """ @@ -30,7 +242,7 @@ def one_step_corrected_estimator( """ influence_kwargs_one_step = influence_kwargs.copy() influence_kwargs_one_step["pointwise_influence"] = False - eif_fn = influence_fn(functional, *test_points, **influence_kwargs_one_step) + eif_fn = influence_estimator(functional, *test_points, **influence_kwargs_one_step) def _corrected_functional(*model: Callable[P, Any]) -> Callable[P, S]: plug_in_estimator = functional(*model) diff --git a/docs/source/tmle.ipynb b/docs/source/tmle.ipynb new file mode 100644 index 000000000..b9fe30e49 --- /dev/null +++ b/docs/source/tmle.ipynb @@ -0,0 +1,930 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Automated doubly robust estimation with ChiRho - TMLE Version" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Outline\n", + "\n", + "- [Setup](#setup)\n", + "\n", + "- [Overview: Systematically adjusting for observed confounding](#overview:-systematically-adjusting-for-observed-confounding)\n", + " - [Task: Treatment effect estimation with observational data](#task:-treatment-effect-estimation-with-observational-data)\n", + " - [Challenge: Confounding](#challenge:-confounding)\n", + " - [Assumptions: All confounders observed](#assumptions:-all-confounders-observed)\n", + " - [Intuition: Statistically adjusting for confounding](#intuition:-statistically-adjusting-for-confounding)\n", + "\n", + "- [Causal Probabilistic Program](#causal-probabilistic-program)\n", + " - [Model description](#model-description)\n", + " - [Generating data](#generating-data)\n", + " - [Fit parameters via maximum likelihood](#fit-parameters-via-maximum-likelihood)\n", + "\n", + "- [Causal Query: average treatment effect (ATE)](#causal-query:-average-treatment-effect-\\(ATE\\))\n", + " - [Defining the target functional](#defining-the-target-functional)\n", + " - [Closed form doubly robust correction](#closed-form-doubly-robust-correction)\n", + " - [Computing automated doubly robust correction via Monte Carlo](#computing-automated-doubly-robust-correction-via-monte-carlo)\n", + " - [Results](#results)\n", + "\n", + "- [References](#references)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we install the necessary Pytorch, Pyro, and ChiRho dependencies for this example." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "NOTE: Redirects are currently not supported in Windows or MacOs.\n" + ] + } + ], + "source": [ + "from typing import Callable, Optional, Tuple\n", + "\n", + "import functools\n", + "import torch\n", + "import math\n", + "import seaborn as sns\n", + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "from pyro.infer import Predictive\n", + "import pyro.contrib.gp as gp\n", + "\n", + "from chirho.counterfactual.handlers import MultiWorldCounterfactual\n", + "from chirho.indexed.ops import IndexSet, gather\n", + "from chirho.interventional.handlers import do\n", + "from chirho.robust.internals.utils import ParamDict\n", + "from chirho.robust.handlers.estimators import one_step_corrected_estimator, tmle\n", + "from chirho.robust.handlers.predictive import PredictiveModel \n", + "from chirho.robust.ops import influence_fn\n", + "\n", + "pyro.settings.set(module_local_params=True)\n", + "\n", + "sns.set_style(\"white\")\n", + "\n", + "pyro.set_rng_seed(321) # for reproducibility" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "In this tutorial, we will use ChiRho to estimate the average treatment effect (ATE) from observational data. We will use a simple example to illustrate the basic concepts of doubly robust estimation and how ChiRho can be used to automate the process for more general summaries of interest. \n", + "\n", + "There are five main steps to our doubly robust estimation procedure but only the last step is different from a standard probabilistic programming workflow:\n", + "1. Write model of interest\n", + " - Define probabilistic model of interest using Pyro\n", + "2. Feed in data\n", + " - Observed data used to train the model\n", + "3. Run inference\n", + " - Use Pyro's rich inference library to fit the model to the data\n", + "4. Define target functional\n", + " - This is the model summary of interest (e.g. average treatment effect)\n", + "5. Compute robust estimate\n", + " - Use ChiRho to compute the doubly robust estimate of the target functional\n", + " - Importantly, this step is automated and does not require refitting the model for each new functional\n", + "\n", + "\n", + "Our proposed automated robust inference pipeline is summarized in the figure below.\n", + "\n", + "![fig1](figures/robust_pipeline.png)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Causal Probabilistic Program\n", + "\n", + "### Model Description\n", + "In this example, we will focus on a cannonical model `CausalGLM` consisting of three types of variables: binary treatment (`A`), confounders (`X`), and response (`Y`). For simplicitly, we assume that the response is generated from a generalized linear model with link function $g$. The model is described by the following generative process:\n", + "\n", + "$$\n", + "\\begin{align*}\n", + "X &\\sim \\text{Normal}(0, I_p) \\\\\n", + "A &\\sim \\text{Bernoulli}(\\pi(X)) \\\\\n", + "\\mu &= \\beta_0 + \\beta_1^T X + \\tau A \\\\\n", + "Y &\\sim \\text{ExponentialFamily}(\\text{mean} = g^{-1}(\\mu))\n", + "\\end{align*}\n", + "$$\n", + "\n", + "where $p$ denotes the number of confounders, $\\pi(X)$ is the probability of treatment conditional on confounders $X$, $\\beta_0$ is the intercept, $\\beta_1$ is the confounder effect, and $\\tau$ is the treatment effect." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class CausalGLM(pyro.nn.PyroModule):\n", + " def __init__(\n", + " self,\n", + " p: int,\n", + " N: int,\n", + " link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n", + " include_prior: bool = True,\n", + " prior_scale: Optional[float] = None,\n", + " ):\n", + " super().__init__()\n", + " self.p = p\n", + " self.N = N\n", + " self.link_fn = link_fn\n", + " self.include_prior = include_prior\n", + " if prior_scale is None:\n", + " self.prior_scale = 1 / math.sqrt(self.p)\n", + " else:\n", + " self.prior_scale = prior_scale\n", + "\n", + " def sample_outcome_weights(self):\n", + " return pyro.sample(\n", + " \"outcome_weights\",\n", + " dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1),\n", + " )\n", + "\n", + " def sample_intercept(self):\n", + " return pyro.sample(\"intercept\", dist.Normal(0.0, 1.0))\n", + "\n", + " def sample_propensity_weights(self):\n", + " return pyro.sample(\n", + " \"propensity_weights\",\n", + " dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1),\n", + " )\n", + "\n", + " def sample_treatment_weight(self):\n", + " return pyro.sample(\"treatment_weight\", dist.Normal(0.0, 1.0))\n", + "\n", + " def sample_covariate_loc_scale(self):\n", + " return torch.zeros(self.p), torch.ones(self.p)\n", + " \n", + " def generate_datum(self, x_loc, x_scale, propensity_weights, outcome_weights, tau, intercept):\n", + " X = pyro.sample(\"X\", dist.Normal(x_loc, x_scale).to_event(1))\n", + " A = pyro.sample(\n", + " \"A\",\n", + " dist.Bernoulli(\n", + " logits=torch.einsum(\"...i,...i->...\", X, propensity_weights)\n", + " ),\n", + " )\n", + " return pyro.sample(\n", + " \"Y\",\n", + " self.link_fn(\n", + " torch.einsum(\"...i,...i->...\", X, outcome_weights) + A * tau + intercept\n", + " ),\n", + " )\n", + "\n", + " def forward(self):\n", + " with pyro.poutine.mask(mask=self.include_prior):\n", + " intercept = self.sample_intercept()\n", + " outcome_weights = self.sample_outcome_weights()\n", + " propensity_weights = self.sample_propensity_weights()\n", + " tau = self.sample_treatment_weight()\n", + " x_loc, x_scale = self.sample_covariate_loc_scale()\n", + " with pyro.plate(\"plate\", self.N, dim=-1):\n", + " return self.generate_datum(x_loc, x_scale, propensity_weights, outcome_weights, tau, intercept)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we will condition on both treatment and confounders to estimate the causal effect of treatment on the outcome. We will use the following causal probabilistic program to do so:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionedCausalGLM(CausalGLM):\n", + " def __init__(\n", + " self,\n", + " X: torch.Tensor,\n", + " A: torch.Tensor,\n", + " Y: torch.Tensor,\n", + " link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n", + " include_prior: bool = True,\n", + " prior_scale: Optional[float] = None,\n", + " ):\n", + " p = X.shape[1]\n", + " N = X.shape[0]\n", + " super().__init__(p, N, link_fn, include_prior, prior_scale)\n", + " self.X = X\n", + " self.A = A\n", + " self.Y = Y\n", + "\n", + " def forward(self):\n", + " with pyro.condition(data={\"X\": self.X, \"A\": self.A, \"Y\": self.Y}):\n", + " return super().forward()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster_plate\n", + "\n", + "plate\n", + "\n", + "\n", + "\n", + "intercept\n", + "\n", + "intercept\n", + "\n", + "\n", + "\n", + "Y\n", + "\n", + "Y\n", + "\n", + "\n", + "\n", + "intercept->Y\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "outcome_weights\n", + "\n", + "outcome_weights\n", + "\n", + "\n", + "\n", + "outcome_weights->Y\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "propensity_weights\n", + "\n", + "propensity_weights\n", + "\n", + "\n", + "\n", + "A\n", + "\n", + "A\n", + "\n", + "\n", + "\n", + "propensity_weights->A\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "treatment_weight\n", + "\n", + "treatment_weight\n", + "\n", + "\n", + "\n", + "treatment_weight->Y\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "X\n", + "\n", + "X\n", + "\n", + "\n", + "\n", + "X->A\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "X->Y\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "A->Y\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "distribution_description_node\n", + "intercept ~ Normal\n", + "outcome_weights ~ Normal\n", + "propensity_weights ~ Normal\n", + "treatment_weight ~ Normal\n", + "X ~ Normal\n", + "A ~ Bernoulli\n", + "Y ~ Normal\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Visualize the model\n", + "pyro.render_model(\n", + " ConditionedCausalGLM(torch.zeros(1, 1), torch.zeros(1), torch.zeros(1)),\n", + " render_params=True, \n", + " render_distributions=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generating data\n", + "\n", + "For evaluation, we generate `N_datasets` datasets, each with `N` samples. We compare vanilla estimates of the target functional with the double robust estimates of the target functional across the `N_sims` datasets. We use a similar data generating process as in Kennedy (2022)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class GroundTruthModel(CausalGLM):\n", + " def __init__(\n", + " self,\n", + " p: int,\n", + " N: int,\n", + " alpha: int,\n", + " beta: int,\n", + " link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n", + " treatment_weight: float = 0.0,\n", + " ):\n", + " super().__init__(p, N, link_fn)\n", + " self.alpha = alpha # sparsity of propensity weights\n", + " self.beta = beta # sparsity of outcome weights\n", + " self.treatment_weight = treatment_weight\n", + "\n", + " def sample_outcome_weights(self):\n", + " outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p)\n", + " outcome_weights[self.beta :] = 0.0\n", + " return outcome_weights\n", + "\n", + " def sample_propensity_weights(self):\n", + " propensity_weights = 1 / math.sqrt(self.alpha) * torch.ones(self.p)\n", + " propensity_weights[self.alpha :] = 0.0\n", + " return propensity_weights\n", + "\n", + " def sample_treatment_weight(self):\n", + " return torch.tensor(self.treatment_weight)\n", + "\n", + " def sample_intercept(self):\n", + " return torch.tensor(0.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "N_datasets = 50\n", + "simulated_datasets = []\n", + "\n", + "# Data configuration\n", + "p = 200\n", + "alpha = 50\n", + "beta = 50\n", + "N_train = 500\n", + "N_test = 500\n", + "treatment_weight = 1.0\n", + "\n", + "true_model = GroundTruthModel(p, N_train+N_test, alpha, beta, treatment_weight=treatment_weight)\n", + "prior_model = CausalGLM(p, N_train+N_test)\n", + "\n", + "# Generate data\n", + "D = Predictive(true_model, num_samples=N_datasets, return_sites=[\"X\", \"A\", \"Y\"], parallel=True)()\n", + "D_train = {k: v[:, :N_train] for k, v in D.items()}\n", + "D_test = {k: v[:, N_train:] for k, v in D.items()}\n", + "\n", + "# D_train : (N_datasets, N_train, p)\n", + "# D_test : (N_datasets, N_test, p)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Fit parameters via maximum likelihood" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "tensor(143220.2969, grad_fn=)\n", + "tensor(142401.8281, grad_fn=)\n", + "tensor(142398.5000, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.5000, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.5312, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.5000, grad_fn=)\n", + "tensor(142398.6094, grad_fn=)\n", + "tensor(142398.4844, grad_fn=)\n", + "tensor(142398.6094, grad_fn=)\n", + "tensor(142398.6562, grad_fn=)\n", + "1\n", + "tensor(142864.0625, grad_fn=)\n", + "tensor(142156.9531, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.7969, grad_fn=)\n", + "tensor(142155.7812, grad_fn=)\n", + "tensor(142155.8281, grad_fn=)\n", + "tensor(142155.8125, grad_fn=)\n", + "tensor(142155.8906, grad_fn=)\n", + "tensor(142155.9688, grad_fn=)\n", + "tensor(142155.7969, grad_fn=)\n", + "2\n", + "tensor(143104.5469, grad_fn=)\n", + "tensor(142385.0469, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.3750, grad_fn=)\n", + "tensor(142384.3906, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.5156, grad_fn=)\n", + "tensor(142384.3750, grad_fn=)\n", + "tensor(142384.3906, grad_fn=)\n", + "tensor(142384.3594, grad_fn=)\n", + "tensor(142384.4062, grad_fn=)\n", + "tensor(142384.4375, grad_fn=)\n", + "tensor(142384.4062, grad_fn=)\n", + "tensor(142384.3906, grad_fn=)\n", + "3\n", + "tensor(142825.3125, grad_fn=)\n", + "tensor(142005.4062, grad_fn=)\n", + "tensor(142000.4844, grad_fn=)\n", + "tensor(142000.4688, grad_fn=)\n", + "tensor(142000.4688, grad_fn=)\n", + "tensor(142000.4688, grad_fn=)\n", + "tensor(142000.4688, grad_fn=)\n", + "tensor(142000.4688, grad_fn=)\n", + "tensor(142000.5000, grad_fn=)\n", + "tensor(142000.4844, grad_fn=)\n", + "tensor(142000.4844, grad_fn=)\n", + "tensor(142000.5000, grad_fn=)\n", + "tensor(142000.5156, grad_fn=)\n", + "tensor(142000.4844, grad_fn=)\n", + "tensor(142000.5156, grad_fn=)\n", + "tensor(142000.4844, grad_fn=)\n", + "tensor(142000.7188, grad_fn=)\n", + "tensor(142000.5938, grad_fn=)\n", + "tensor(142000.6094, grad_fn=)\n", + "tensor(142000.5000, grad_fn=)\n", + "4\n", + "tensor(143287.3281, grad_fn=)\n", + "tensor(142508.1094, grad_fn=)\n", + "tensor(142506.4531, grad_fn=)\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[13], line 20\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m j \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m2000\u001b[39m):\n\u001b[1;32m 19\u001b[0m adam\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 20\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43melbo\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m j \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m100\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28mprint\u001b[39m(loss)\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/pyro/infer/elbo.py:25\u001b[0m, in \u001b[0;36mELBOModule.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43melbo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdifferentiable_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mguide\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:121\u001b[0m, in \u001b[0;36mTrace_ELBO.differentiable_loss\u001b[0;34m(self, model, guide, *args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m\n\u001b[1;32m 120\u001b[0m surrogate_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m\n\u001b[0;32m--> 121\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m model_trace, guide_trace \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_traces(model, guide, args, kwargs):\n\u001b[1;32m 122\u001b[0m loss_particle, surrogate_loss_particle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_differentiable_loss_particle(\n\u001b[1;32m 123\u001b[0m model_trace, guide_trace\n\u001b[1;32m 124\u001b[0m )\n\u001b[1;32m 125\u001b[0m surrogate_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m surrogate_loss_particle \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_particles\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/pyro/infer/elbo.py:237\u001b[0m, in \u001b[0;36mELBO._get_traces\u001b[0;34m(self, model, guide, args, kwargs)\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_particles):\n\u001b[0;32m--> 237\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mguide\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:57\u001b[0m, in \u001b[0;36mTrace_ELBO._get_trace\u001b[0;34m(self, model, guide, args, kwargs)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_get_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, model, guide, args, kwargs):\n\u001b[1;32m 53\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;124;03m Returns a single trace from the guide, and the model that is run\u001b[39;00m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;124;03m against it.\u001b[39;00m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 57\u001b[0m model_trace, guide_trace \u001b[38;5;241m=\u001b[39m \u001b[43mget_importance_trace\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mflat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_plate_nesting\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mguide\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_validation_enabled():\n\u001b[1;32m 61\u001b[0m check_if_enumerated(guide_trace)\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/pyro/infer/enum.py:75\u001b[0m, in \u001b[0;36mget_importance_trace\u001b[0;34m(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)\u001b[0m\n\u001b[1;32m 72\u001b[0m guide_trace \u001b[38;5;241m=\u001b[39m prune_subsample_sites(guide_trace)\n\u001b[1;32m 73\u001b[0m model_trace \u001b[38;5;241m=\u001b[39m prune_subsample_sites(model_trace)\n\u001b[0;32m---> 75\u001b[0m \u001b[43mmodel_trace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_log_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 76\u001b[0m guide_trace\u001b[38;5;241m.\u001b[39mcompute_score_parts()\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_validation_enabled():\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/pyro/poutine/trace_struct.py:230\u001b[0m, in \u001b[0;36mTrace.compute_log_prob\u001b[0;34m(self, site_filter)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlog_prob\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m site:\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 230\u001b[0m log_p \u001b[38;5;241m=\u001b[39m \u001b[43msite\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfn\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 231\u001b[0m \u001b[43m \u001b[49m\u001b[43msite\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvalue\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msite\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43margs\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msite\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mkwargs\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 232\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 234\u001b[0m _, exc_value, traceback \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39mexc_info()\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/torch/distributions/independent.py:99\u001b[0m, in \u001b[0;36mIndependent.log_prob\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlog_prob\u001b[39m(\u001b[38;5;28mself\u001b[39m, value):\n\u001b[0;32m---> 99\u001b[0m log_prob \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_dist\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_prob\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _sum_rightmost(log_prob, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreinterpreted_batch_ndims)\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/torch/distributions/normal.py:83\u001b[0m, in \u001b[0;36mNormal.log_prob\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 81\u001b[0m var \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscale \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 82\u001b[0m log_scale \u001b[38;5;241m=\u001b[39m math\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscale) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscale, Real) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscale\u001b[38;5;241m.\u001b[39mlog()\n\u001b[0;32m---> 83\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;241m-\u001b[39m(\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloc\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m) \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m var) \u001b[38;5;241m-\u001b[39m log_scale \u001b[38;5;241m-\u001b[39m math\u001b[38;5;241m.\u001b[39mlog(math\u001b[38;5;241m.\u001b[39msqrt(\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m math\u001b[38;5;241m.\u001b[39mpi))\n", + "File \u001b[0;32m~/opt/anaconda3/envs/chirho-dynamic/lib/python3.11/site-packages/torch/_tensor.py:34\u001b[0m, in \u001b[0;36m_handle_torch_function_and_wrap_type_error_to_not_implemented..wrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_handle_torch_function_and_wrap_type_error_to_not_implemented\u001b[39m(f):\n\u001b[1;32m 32\u001b[0m assigned \u001b[38;5;241m=\u001b[39m functools\u001b[38;5;241m.\u001b[39mWRAPPER_ASSIGNMENTS\n\u001b[0;32m---> 34\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(f, assigned\u001b[38;5;241m=\u001b[39massigned)\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 37\u001b[0m \u001b[38;5;66;03m# See https://github.com/pytorch/pytorch/issues/75462\u001b[39;00m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function(args):\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "fitted_params = []\n", + "for i in range(N_datasets):\n", + " print(i)\n", + "\n", + " # Fit model using maximum likelihood\n", + " conditioned_model = ConditionedCausalGLM(\n", + " X=D_train[\"X\"][i], A=D_train[\"A\"][i], Y=D_train[\"Y\"][i]\n", + " )\n", + " \n", + " guide_train = pyro.infer.autoguide.AutoDelta(conditioned_model)\n", + " elbo = pyro.infer.Trace_ELBO()(conditioned_model, guide_train)\n", + "\n", + " # initialize parameters\n", + " elbo()\n", + " adam = torch.optim.Adam(elbo.parameters(), lr=0.03)\n", + "\n", + " # Do gradient steps\n", + " for j in range(2000):\n", + " adam.zero_grad()\n", + " loss = elbo()\n", + " if j % 100 == 0:\n", + " print(loss)\n", + " loss.backward()\n", + " adam.step()\n", + "\n", + " theta_hat = {\n", + " k: v.clone().detach().requires_grad_(True) for k, v in guide_train().items()\n", + " }\n", + " fitted_params.append(theta_hat)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Causal Query: Sample Average treatment effect (ATE)\n", + "\n", + "The average treatment effect summarizes, on average, how much the treatment changes the response, $ATE = \\mathbb{E}[Y|do(A=1)] - \\mathbb{E}[Y|do(A=0)]$. The `do` notation indicates that the expectations are taken according to *intervened* versions of the model, with $A$ set to a particular value. Note from our [tutorial](tutorial_i.ipynb) that this is different from conditioning on $A$ in the original `causal_model`, which assumes $X$ and $T$ are dependent.\n", + "\n", + "\n", + "To implement this query in ChiRho, we define the `SATEFunctional` class which take in a `model` and `guide` and returns the average treatment effect by simulating from the posterior predictive distribution of the model and guide." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the target functional" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class SATEFunctional(torch.nn.Module):\n", + " def __init__(self, model: Callable, *, num_monte_carlo: int = 100):\n", + " super().__init__()\n", + " self.model = model\n", + " self.num_monte_carlo = num_monte_carlo\n", + " \n", + " def forward(self, *args, **kwargs):\n", + " with MultiWorldCounterfactual():\n", + " with do(actions=dict(A=(torch.tensor(0.0), torch.tensor(1.0)))):\n", + " Ys = self.model(*args, **kwargs)\n", + " Y0 = gather(Ys, IndexSet(A={1}), event_dim=0)\n", + " Y1 = gather(Ys, IndexSet(A={2}), event_dim=0)\n", + " sate = (Y1 - Y0).mean(dim=-1, keepdim=True).squeeze()\n", + " return pyro.deterministic(\"SATE\", sate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SATE = SATEFunctional(true_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Closed form doubly robust correction\n", + "\n", + "For the average treatment effect functional, there exists a closed-form analytical formula for the doubly robust correction. This formula is derived in Kennedy (2022) and is implemented below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "from chirho.robust.ops import Functional, Point, P, S, T\n", + "\n", + "def SATECausalGLM_analytic_influence(functional: Functional[P, S], \n", + " point: Point[T], \n", + " pointwise_influence: bool = True,\n", + " **kwargs) -> Functional[P, S]:\n", + " # assert isinstance(functional, SATEFunctional)\n", + " def new_functional(model: Callable[P, Any]) -> Callable[P, S]:\n", + " assert isinstance(model.model, CausalGLM)\n", + " theta = dict(model.guide.named_parameters())\n", + " def new_model(*args, **kwargs):\n", + " X = point[\"X\"]\n", + " A = point[\"A\"]\n", + " Y = point[\"Y\"]\n", + " \n", + " pi_X = torch.sigmoid(torch.einsum(\"...i,...i->...\", X, theta[\"propensity_weights_param\"]))\n", + " mu_X = (\n", + " torch.einsum(\"...i,...i->...\", X, theta[\"outcome_weights_param\"])\n", + " + A * theta[\"treatment_weight_param\"]\n", + " + theta[\"intercept_param\"]\n", + " )\n", + " analytic_eif_at_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)\n", + " if pointwise_influence:\n", + " return analytic_eif_at_pts\n", + " else:\n", + " return analytic_eif_at_pts.mean()\n", + "\n", + "\n", + " return new_model\n", + " return new_functional\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Closed form expression\n", + "# def closed_form_doubly_robust_ate_correction(X_test, theta) -> Tuple[torch.Tensor, torch.Tensor]:\n", + "# X = X_test[\"X\"]\n", + "# A = X_test[\"A\"]\n", + "# Y = X_test[\"Y\"]\n", + "# pi_X = torch.sigmoid(X.mv(theta[\"propensity_weights\"]))\n", + "# mu_X = (\n", + "# X.mv(theta[\"outcome_weights\"])\n", + "# + A * theta[\"treatment_weight\"]\n", + "# + theta[\"intercept\"]\n", + "# )\n", + "# analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)\n", + "# analytic_correction = analytic_eif_at_test_pts.mean()\n", + "# return analytic_correction, analytic_eif_at_test_pts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Computing automated doubly robust correction via Monte Carlo\n", + "\n", + "While the doubly robust correction term is known in closed-form for the average treatment effect functional, our `one_step_correction` function in `ChiRho` works for a wide class of other functionals. We focus on the average treatment effect functional here so that we have a ground truth to compare `one_step_correction` against." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tracemalloc\n", + "\n", + "tracemalloc.start()\n", + "\n", + "# Helper class to create a trivial guide that returns the maximum likelihood estimate\n", + "class MLEGuide(torch.nn.Module):\n", + " def __init__(self, mle_est: ParamDict):\n", + " super().__init__()\n", + " self.names = list(mle_est.keys())\n", + " for name, value in mle_est.items():\n", + " setattr(self, name + \"_param\", torch.nn.Parameter(value))\n", + "\n", + " def forward(self, *args, **kwargs):\n", + " for name in self.names:\n", + " value = getattr(self, name + \"_param\")\n", + " pyro.sample(\n", + " name, pyro.distributions.Delta(value, event_dim=len(value.shape))\n", + " )\n", + "\n", + "# Compute doubly robust ATE estimates using both the automated and closed form expressions\n", + "# estimators = {\"tmle\": tmle, \"one_step\": one_step_corrected_estimator}\n", + "estimators = {\"one_step\": one_step_corrected_estimator}\n", + "estimator_kwargs = {\"tmle\": {\"learning_rate\":5e-5,\n", + " \"n_grad_steps\":500,\n", + " \"n_tmle_steps\":1,\n", + " \"num_nmc_samples\":1000,\n", + " \"num_grad_samples\":N_test}, \"one_step\": {}}\n", + "# influences = {\"analytic\": SATECausalGLM_analytic_influence, \"monte_carlo\": influence_fn}\n", + "influences = {\"analytic\": SATECausalGLM_analytic_influence}\n", + "\n", + "estimates = {f\"{influence}-{estimator}\": torch.zeros(N_datasets) for influence in influences.keys() for estimator in estimators.keys()}\n", + "estimates[\"plug-in-mle\"] = torch.zeros(N_datasets)\n", + "estimates[\"plug-in-prior\"] = torch.zeros(N_datasets)\n", + "estimates[\"plug-in-truth\"] = torch.zeros(N_datasets)\n", + "\n", + "functional = functools.partial(SATEFunctional, num_monte_carlo=10000)\n", + "\n", + "for i in range(N_datasets):\n", + " print(\"plug-in-prior\", i)\n", + " estimates[\"plug-in-prior\"][i] = functional(prior_model)().item()\n", + "\n", + " print(\"plug-in-truth\", i)\n", + " estimates[\"plug-in-truth\"][i] = functional(true_model)().item()\n", + "\n", + " # D_test = simulated_datasets[i][1]\n", + " theta_hat = fitted_params[i]\n", + " mle_guide = MLEGuide(theta_hat)\n", + "\n", + " model = PredictiveModel(CausalGLM(p, N_test), mle_guide)\n", + " \n", + " print(\"plug-in-mle\", i)\n", + " estimates[\"plug-in-mle\"][i] = functional(model)().detach().item()\n", + "\n", + " for estimator_str, estimator in estimators.items():\n", + " for influence_str, influence in influences.items():\n", + " if estimator_str == \"tmle\" and influence_str == \"monte_carlo\":\n", + " continue\n", + "\n", + " print(estimator_str, influence_str, i)\n", + " \n", + " estimate = estimator(\n", + " functional, \n", + " {\"X\": D_test[\"X\"][i], \"A\": D_test[\"A\"][i], \"Y\": D_test[\"Y\"][i]},\n", + " num_samples_outer=max(10000, 100 * p), \n", + " num_samples_inner=1,\n", + " influence_estimator=influence,\n", + " **estimator_kwargs[estimator_str]\n", + " )(model)()\n", + "\n", + " estimates[f\"{influence_str}-{estimator_str}\"][i] = estimate.detach().item()\n", + "\n", + " print(tracemalloc.get_traced_memory())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = pd.DataFrame(estimates)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The true treatment effect is 0, so a mean estimate closer to zero is better\n", + "results.describe().round(2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize the results\n", + "fig, ax = plt.subplots()\n", + "\n", + "for col in results.columns:\n", + " sns.kdeplot(results[col], ax=ax, label=col)\n", + "\n", + "ax.axvline(treatment_weight, color=\"black\", label=\"True ATE\", linestyle=\"--\")\n", + "ax.set_yticks([])\n", + "sns.despine()\n", + "ax.legend(loc=\"upper right\")\n", + "ax.set_xlabel(\"ATE Estimate\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plt.scatter(\n", + "# results['automated_monte_carlo_correction'],\n", + "# results['analytic_correction'],\n", + "# )\n", + "# plt.plot(np.linspace(-.2, .5), np.linspace(-.2, .5), color=\"black\", linestyle=\"dashed\")\n", + "# plt.xlabel(\"DR-Monte Carlo\")\n", + "# plt.ylabel(\"DR-Analytic\")\n", + "# sns.despine()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "Kennedy, Edward. \"Towards optimal doubly robust estimation of heterogeneous causal effects\", 2022. https://arxiv.org/abs/2004.14497." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "basis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 47c6dcd3d..0b425f72a 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ ] DYNAMICAL_REQUIRE = ["torchdiffeq"] +ROBUST_REQUIRE = ["torchopt"] setup( name="chirho", @@ -45,8 +46,9 @@ ], extras_require={ "dynamical": DYNAMICAL_REQUIRE, + "robust": ROBUST_REQUIRE, "extras": EXTRAS_REQUIRE, - "test": EXTRAS_REQUIRE + DYNAMICAL_REQUIRE + "test": EXTRAS_REQUIRE + DYNAMICAL_REQUIRE + ROBUST_REQUIRE + [ "pytest", "pytest-cov", diff --git a/tests/robust/test_handlers.py b/tests/robust/test_handlers.py index e43015282..9cd55e749 100644 --- a/tests/robust/test_handlers.py +++ b/tests/robust/test_handlers.py @@ -1,3 +1,4 @@ +import copy import functools from typing import Callable, List, Mapping, Optional, Set, Tuple, TypeVar @@ -6,7 +7,7 @@ import torch from typing_extensions import ParamSpec -from chirho.robust.handlers.estimators import one_step_corrected_estimator +from chirho.robust.handlers.estimators import one_step_corrected_estimator, tmle from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel from .robust_fixtures import SimpleGuide, SimpleModel @@ -42,7 +43,7 @@ @pytest.mark.parametrize("num_samples_outer,num_samples_inner", [(10, None), (10, 100)]) @pytest.mark.parametrize("cg_iters", [None, 1, 10]) @pytest.mark.parametrize("num_predictive_samples", [1, 5]) -@pytest.mark.parametrize("estimation_method", [one_step_corrected_estimator]) +@pytest.mark.parametrize("estimation_method", [one_step_corrected_estimator, tmle]) def test_estimator_smoke( model, guide, @@ -66,6 +67,20 @@ def test_estimator_smoke( )().items() } + predictive_model = PredictiveModel(model, guide) + + prev_params = copy.deepcopy(dict(predictive_model.named_parameters())) + + if estimation_method == tmle: + estimator_kwargs = { + "n_tmle_steps": 1, + "n_grad_steps": 2, + "num_nmc_samples": 10, + "num_grad_samples": 10, + } + else: + estimator_kwargs = {} + estimator = estimation_method( functools.partial(PredictiveFunctional, num_samples=num_predictive_samples), test_datum, @@ -73,7 +88,8 @@ def test_estimator_smoke( num_samples_outer=num_samples_outer, num_samples_inner=num_samples_inner, cg_iters=cg_iters, - )(PredictiveModel(model, guide)) + **estimator_kwargs, + )(predictive_model) estimate_on_test: Mapping[str, torch.Tensor] = estimator() assert len(estimate_on_test) > 0 @@ -83,3 +99,8 @@ def test_estimator_smoke( assert not torch.isclose( v, torch.zeros_like(v) ).all(), f"{estimation_method} estimator for {k} was zero" + + # Assert estimator doesn't have side effects on model parameters. + new_params = dict(predictive_model.named_parameters()) + for k, v in prev_params.items(): + assert torch.allclose(v, new_params[k]), f"{k} was updated"