Skip to content

Commit

Permalink
Merge pull request #109 from TARGENE/cv_check_and_doc
Browse files Browse the repository at this point in the history
Cv check and doc
  • Loading branch information
olivierlabayle authored Apr 29, 2024
2 parents 3d27968 + d3ba349 commit 5d4a8e9
Show file tree
Hide file tree
Showing 17 changed files with 623 additions and 124 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.16.0"
version = "0.16.1"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ makedocs(;
joinpath("examples", "super_learning.md"),
joinpath("examples", "double_robustness.md")
],
"Estimators' Cheat Sheet" => "estimators_cheatsheet.md",
"Resources" => "resources.md",
"API Reference" => "api.md"
],
Expand Down
Binary file added docs/src/assets/sample_splitting.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
304 changes: 304 additions & 0 deletions docs/src/estimators_cheatsheet.md

Large diffs are not rendered by default.

40 changes: 36 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ CurrentModule = TMLE

## Overview

TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in efficient and unbiased estimation of causal effects, you are in the right place. Since TMLE uses machine-learning methods to estimate nuisance estimands, the present package is based upon [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/).
TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in leveraging the power of modern machine-learning methods while preserving interpretability and statistical inference guarantees, you are in the right place. TMLE.jl is compatible with any [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) compliant algorithm and any dataset respecting the [Tables](https://tables.juliadata.org/stable/) interface.

## Installation

Expand All @@ -20,7 +20,7 @@ Pkg> add TMLE

To run an estimation procedure, we need 3 ingredients:

1. A dataset: here a simulation dataset.
### 1. A dataset: here a simulation dataset

For illustration, assume we know the actual data generating process is as follows:

Expand Down Expand Up @@ -52,7 +52,7 @@ dataset = (Y=Y, T=categorical(T), W=W)
nothing # hide
```

2. A quantity of interest: here the Average Treatment Effect (ATE).
### 2. A quantity of interest: here the Average Treatment Effect (ATE)

The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as:

Expand All @@ -64,7 +64,7 @@ The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as
)
```

3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE).
### 3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE)

```@example quick-start
tmle = TMLEE()
Expand All @@ -79,3 +79,35 @@ using Test # hide
@test pvalue(OneSampleTTest(result, 2.5)) > 0.05 # hide
nothing # hide
```

## Scope and Distinguishing Features

The goal of this package is to provide an entry point for semi-parametric asymptotic unbiased and efficient estimation in Julia. The two main general estimators that are known to achieve these properties are the One-Step estimator and the Targeted Maximum-Likelihood estimator. Most of the current effort has been centered around estimands that are composite of the counterfactual mean.

Distinguishing Features:

- Estimands: Counterfactual Mean, Average Treatment Effect, Interactions, Any composition thereof
- Estimators: TMLE, One-Step, in both canonical and cross-validated versions.
- Machine-Learning: Any [MLJ](https://alan-turing-institute.github.io/MLJ.jl/stable/) compatible model
- Dataset: Any dataset respecting the [Tables](https://tables.juliadata.org/stable/) interface (e.g. [DataFrames.jl](https://dataframes.juliadata.org/stable/))
- Factorial Treatment Variables:
- Multiple treatments
- Categorical treatment values

## Citing TMLE.jl

If you use TMLE.jl for your own work and would like to cite us, here are the BibTeX and APA formats:

- BibTeX

```bibtex
@software{Labayle_TMLE_jl,
author = {Labayle, Olivier and Beentjes, Sjoerd and Khamseh, Ava and Ponting, Chris},
title = {{TMLE.jl}},
url = {https://github.com/olivierlabayle/TMLE.jl}
}
```

- APA

Labayle, O., Beentjes, S., Khamseh, A., & Ponting, C. TMLE.jl [Computer software]. https://github.com/olivierlabayle/TMLE.jl
8 changes: 7 additions & 1 deletion docs/src/resources.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@ These are two very clear introductions to causal inference and semi-parametric e
- [Introduction to Modern Causal Inference](https://alejandroschuler.github.io/mci/) (Alejandro Schuler, Mark J. van der Laan).
- [A Ride in Targeted Learning Territory](https://achambaz.github.io/tlride/) (David Benkeser, Antoine Chambaz).

## Text Books
## Youtube

- [Targeted Learning Webinar Series](https://youtube.com/playlist?list=PLy_CaFomwGGGH10tbq9zSyfHVrdklMaLe&si=BfJZ2fvDtGUZwELy)
- [TL Briefs](https://youtube.com/playlist?list=PLy_CaFomwGGFMxFtf4gkmC70dP9J6Q3Wa&si=aBZUnjJtOidIjhwR)

## Books and Lecture Notes

- [Targeted Learning](https://link.springer.com/book/10.1007/978-1-4419-9782-1) (Mark J. van der Laan, Sherri Rose).
- [STATS 361: Causal Inference](https://web.stanford.edu/~swager/stats361.pdf)

## Journal articles

Expand Down
88 changes: 68 additions & 20 deletions docs/src/user_guide/estimation.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ CurrentModule = TMLE

# Estimation

## Estimating a single Estimand
## Constructing and Using Estimators

```@setup estimation
using Random
Expand Down Expand Up @@ -51,11 +51,12 @@ scm = SCM([
)
```

Once a statistical estimand has been defined, we can proceed with estimation. At the moment, we provide 3 main types of estimators:
Once a statistical estimand has been defined, we can proceed with estimation. There are two semi-parametric efficient estimators in TMLE.jl:

- Targeted Maximum Likelihood Estimator (`TMLEE`)
- One-Step Estimator (`OSE`)
- Naive Plugin Estimator (`NAIVE`)
- The Targeted Maximum-Likelihood Estimator (`TMLEE`)
- The One-Step Estimator (`OSE`)

While they have similar asymptotic properties, their finite sample performance may be different. They also have a very distinguishing feature, the TMLE is a plugin estimator, which means it respects the natural bounds of the estimand of interest. In contrast, the OSE may in theory report values outside these bounds. In practice, this is not often the case and the estimand of interest may not impose any restriction on its domain.

Drawing from the example dataset and `SCM` from the Walk Through section, we can estimate the ATE for `T₁`. Let's use TMLE:

Expand All @@ -72,27 +73,25 @@ result₁
nothing # hide
```

We see that both models corresponding to variables `Y` and `T₁` were fitted in the process but that the model for `T₂` was not because it was not necessary to estimate this estimand.

The `cache` contains estimates for the nuisance functions that were necessary to estimate the ATE. For instance, we can see what is the value of ``\epsilon`` corresponding to the clever covariate.
The `cache` (see below) contains estimates for the nuisance functions that were necessary to estimate the ATE. For instance, we can see what is the value of ``\epsilon`` corresponding to the clever covariate.

```@example estimation
ϵ = last_fluctuation_epsilon(cache)
```

The `result₁` structure corresponds to the estimation result and should report 3 main elements:
The `result₁` structure corresponds to the estimation result and will display the result of a T-Test including:

- A point estimate.
- A 95% confidence interval.
- A p-value (Corresponding to the test that the estimand is different than 0).

This is only summary statistics but since both the TMLE and OSE are asymptotically linear estimators, standard Z/T tests from [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) can be performed.
Both the TMLE and OSE are asymptotically linear estimators, standard Z/T tests from [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) can be performed and `confint` and `pvalue` methods used.

```@example estimation
tmle_test_result₁ = OneSampleTTest(result₁)
tmle_test_result₁ = pvalue(OneSampleTTest(result₁))
```

We could now get an interest in the Average Treatment Effect of `T₂` that we will estimate with an `OSE`:
Let us now turn to the Average Treatment Effect of `T₂`, we will estimate it with a `OSE`:

```@example estimation
Ψ₂ = ATE(
Expand All @@ -109,24 +108,73 @@ nothing # hide

Again, required nuisance functions are fitted and stored in the cache.

## CV-Estimation
## Specifying Models

Both TMLE and OSE can be used with sample-splitting, which, for an additional computational cost, further reduces the assumptions we need to make regarding our data generating process ([see here](https://arxiv.org/abs/2203.06469)). Note that this sample-splitting procedure should not be confused with the sample-splitting happening in Super Learning. Using both CV-TMLE and Super-Learning will result in two nested sample-splitting loops.
By default, TMLE.jl uses generalized linear models for the estimation of relevant and nuisance factors such as the outcome mean and the propensity score. However, this is not the recommended usage since the estimators' performance is closely related to how well we can estimate these factors. More sophisticated models can be provided using the `models` keyword argument of each estimator which is essentially a `NamedTuple` mapping variables' names to their respective model.

To leverage sample-splitting, simply specify a `resampling` strategy when building an estimator:
Rather than specifying a specific model for each variable it may be easier to override the default models using the `default_models` function:

For example one can override all default models with XGBoost models from `MLJXGBoostInterface`:

```@example estimation
cvtmle = TMLEE(resampling=CV())
cvresult₁, _ = cvtmle(Ψ₁, dataset);
using MLJXGBoostInterface
xgboost_regressor = XGBoostRegressor()
xgboost_classifier = XGBoostClassifier()
models = default_models(
Q_binary=xgboost_classifier,
Q_continuous=xgboost_regressor,
G=xgboost_classifier
)
tmle_gboost = TMLEE(models=models)
```

Similarly, one could build CV-OSE:
The advantage of using `default_models` is that it will automatically prepend each model with a [ContinuousEncoder](https://alan-turing-institute.github.io/MLJ.jl/dev/transformers/#MLJModels.ContinuousEncoder) to make sure the correct types are passed to the downstream models.

Super Learning ([Stack](https://alan-turing-institute.github.io/MLJ.jl/dev/model_stacking/#Model-Stacking)) as well as variable specific models can be defined as well. Here is a more customized version:

```@example estimation
lr = LogisticClassifier(lambda=0.)
stack_binary = Stack(
metalearner=lr,
xgboost=xgboost_classifier,
lr=lr
)
models = (
T₁ = with_encoder(xgboost_classifier), # T₁ with XGBoost prepended with a Continuous Encoder
default_models( # For all other variables use the following defaults
Q_binary=stack_binary, # A Super Learner
Q_continuous=xgboost_regressor, # An XGBoost
# Unspecified G defaults to Logistic Regression
)...
)
tmle_custom = TMLEE(models=models)
```

Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `NamedTuple`.

## CV-Estimation

Canonical TMLE/OSE are essentially using the dataset twice, once for the estimation of the nuisance functions and once for the estimation of the parameter of interest. This means that there is a risk of over-fitting and residual bias ([see here](https://arxiv.org/abs/2203.06469) for some discussion). One way to address this limitation is to use a technique called sample-splitting / cross-validating. In order to activate the sample-splitting mode, simply provide a `MLJ.ResamplingStrategy` using the `resampling` keyword argument:

```@example estimation
TMLEE(resampling=StratifiedCV());
```

or

```julia
cvose = OSE(resampling=CV(nfolds=3))
OSE(resampling=StratifiedCV(nfolds=3));
```

## Caching model fits
There are some practical considerations

- Choice of `resampling` Strategy: The theory behind sample-splitting requires the nuisance functions to be sufficiently well estimated on **each and every** fold. A practical aspect of it is that each fold should contain a sample representative of the dataset. In particular, when the treatment and outcome variables are categorical it is important to make sure the proportions are preserved. This is typically done using `StratifiedCV`.
- Computational Complexity: Sample-splitting results in ``K`` fits of the nuisance functions, drastically increasing computational complexity. In particular, if the nuisance functions are estimated using (P-fold) Super-Learning, this will result in two nested cross-validation loops and ``K \times P`` fits.
- Caching of Nuisance Functions: Because the `resampling` strategy typically needs to preserve the outcome and treatment proportions, very little reuse of cached models is possible (see [Caching Models](@ref)).

## Caching Models

Let's now see how the `cache` can be reused with a new estimand, say the Total Average Treatment Effect of both `T₁` and `T₂`.

Expand Down
26 changes: 17 additions & 9 deletions src/counterfactual_mean_based/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,17 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict()
machine_cache=tmle.machine_cache
)
# Estimation results after TMLE
IC, Ψ̂ = gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
IC, Ψ̂ = gradient_and_estimate(tmle, Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
σ̂ = std(IC)
n = size(IC, 1)
verbosity >= 1 && @info "Done."
# update!(cache, relevant_factors, targeted_factors_estimate)
return TMLEstimate(Ψ, Ψ̂, σ̂, n, IC), cache
end

gradient_and_estimate(::TMLEE, Ψ, factors, dataset; ps_lowerbound=1e-8) =
gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound)

#####################################################################
### OSE ###
#####################################################################
Expand Down Expand Up @@ -267,14 +270,14 @@ ose = OSE()
OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) =
OSE(models, resampling, ps_lowerbound, machine_cache)

function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1)
function (ose::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1)
# Check the estimand against the dataset
check_treatment_levels(Ψ, dataset)
# Initial fit of the SCM's relevant factors
initial_factors = get_relevant_factors(Ψ)
nomissing_dataset = nomissing(dataset, variables(initial_factors))
initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset, estimator.resampling)
initial_factors_estimator = CMRelevantFactorsEstimator(estimator.resampling, estimator.models)
initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset, ose.resampling)
initial_factors_estimator = CMRelevantFactorsEstimator(ose.resampling, ose.models)
initial_factors_estimate = initial_factors_estimator(
initial_factors,
initial_factors_dataset;
Expand All @@ -283,16 +286,21 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic
)
# Get propensity score truncation threshold
n = nrows(nomissing_dataset)
ps_lowerbound = ps_lower_bound(n, estimator.ps_lowerbound)
ps_lowerbound = ps_lower_bound(n, ose.ps_lowerbound)

# Gradient and estimate
IC, Ψ̂ = gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
IC_mean = mean(IC)
IC .-= IC_mean
IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
σ̂ = std(IC)
n = size(IC, 1)
verbosity >= 1 && @info "Done."
return OSEstimate(Ψ, Ψ̂ + IC_mean, σ̂, n, IC), cache
return OSEstimate(Ψ, Ψ̂, σ̂, n, IC), cache
end

function gradient_and_estimate(::OSE, Ψ, factors, dataset; ps_lowerbound=1e-8)
IC, Ψ̂ = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound)
IC_mean = mean(IC)
IC .-= IC_mean
return IC, Ψ̂ + IC_mean
end

#####################################################################
Expand Down
9 changes: 3 additions & 6 deletions src/counterfactual_mean_based/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ function counterfactual_aggregate(Ψ::StatisticalCMCompositeEstimand, Q, dataset
return ctf_agg
end

compute_estimate(ctf_aggregate, ::Nothing) = mean(ctf_aggregate)

compute_estimate(ctf_aggregate, train_validation_indices) =
mean(compute_estimate(ctf_aggregate[val_indices], nothing) for (_, val_indices) in train_validation_indices)
plugin_estimate(ctf_aggregate) = mean(ctf_aggregate)


"""
Expand All @@ -53,11 +50,11 @@ function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=
end


function gradient_and_estimate::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8)
function gradient_and_plugin_estimate::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8)
Q = factors.outcome_mean
G = factors.propensity_score
ctf_agg = counterfactual_aggregate(Ψ, Q, dataset)
Ψ̂ = compute_estimate(ctf_agg, train_validation_indices_from_factors(factors))
Ψ̂ = plugin_estimate(ctf_agg)
IC = ∇YX(Ψ, Q, G, dataset; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂)
return IC, Ψ̂
end
Expand Down
Loading

2 comments on commit 5d4a8e9

@olivierlabayle
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/105840

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.16.1 -m "<description of version>" 5d4a8e95711abfaabde310239026c6929cc8d270
git push origin v0.16.1

Please sign in to comment.