Skip to content

Commit

Permalink
Merge pull request #112 from TARGENE/treatment_values
Browse files Browse the repository at this point in the history
For 0.17.0 release
  • Loading branch information
olivierlabayle authored Aug 21, 2024
2 parents 5d4a8e9 + e9aca43 commit 248bc05
Show file tree
Hide file tree
Showing 36 changed files with 773 additions and 684 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
os:
- ubuntu-latest
Expand Down
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.16.1"
version = "0.17.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -18,6 +19,7 @@ MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
Expand Down Expand Up @@ -46,7 +48,7 @@ JSON = "0.21.4"
LogExpFunctions = "0.3"
MLJBase = "1.0.1"
MLJGLMInterface = "0.3.4"
MLJModels = "0.15, 0.16"
MLJModels = "0.15, 0.16, 0.17"
MetaGraphsNext = "0.7"
Missings = "1.0"
PrecompileTools = "1.1.1"
Expand All @@ -55,7 +57,9 @@ TableOperations = "1.2"
Tables = "1.6"
YAML = "0.4.9"
Zygote = "0.6.69"
julia = "1.6, 1.7, 1"
OrderedCollections = "1.6.3"
AutoHashEquals = "2.1.0"
julia = "1.10, 1"

[extras]
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
Expand Down
42 changes: 27 additions & 15 deletions docs/src/user_guide/estimands.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,7 @@ statisticalΨ = ATE(
)
```

- Factorial Treatments

It is possible to generate a `ComposedEstimand` containing all linearly independent IATEs from a set of treatment values or from a dataset. For that purpose, use the `factorialEstimand` function.

## The Interaction Average Treatment Effect
## The Average Interaction Effect

- Causal Question:

Expand All @@ -136,14 +132,14 @@ For a general higher-order definition, please refer to [Higher-order interaction
For two points interaction with both treatment and control levels ``0`` and ``1`` for ease of notation:

```math
IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[Y|do(T_1=1, T_2=1)] - \mathbb{E}[Y|do(T_1=1, T_2=0)] \\
AIE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[Y|do(T_1=1, T_2=1)] - \mathbb{E}[Y|do(T_1=1, T_2=0)] \\
- \mathbb{E}[Y|do(T_1=0, T_2=1)] + \mathbb{E}[Y|do(T_1=0, T_2=0)]
```

- Statistical Estimand (via backdoor adjustment):

```math
IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|T_1=1, T_2=1, \textbf{W}] - \mathbb{E}[Y|T_1=1, T_2=0, \textbf{W}] \\
AIE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|T_1=1, T_2=1, \textbf{W}] - \mathbb{E}[Y|T_1=1, T_2=0, \textbf{W}] \\
- \mathbb{E}[Y|T_1=0, T_2=1, \textbf{W}] + \mathbb{E}[Y|T_1=0, T_2=0, \textbf{W}]]
```

Expand All @@ -152,7 +148,7 @@ IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[
A causal estimand is given by:

```@example estimands
causalΨ = IATE(
causalΨ = AIE(
outcome=:Y,
treatment_values=(
T₁=(case=1, control=0),
Expand All @@ -170,7 +166,7 @@ statisticalΨ = identify(causalΨ, scm)
or defined directly:

```@example estimands
statisticalΨ = IATE(
statisticalΨ = AIE(
outcome=:Y,
treatment_values=(
T₁=(case=1, control=0),
Expand All @@ -182,13 +178,11 @@ statisticalΨ = IATE(

- Factorial Treatments

It is possible to generate a `ComposedEstimand` containing all linearly independent IATEs from a set of treatment values or from a dataset. For that purpose, use the `factorialEstimand` function.

## Composed Estimands
It is possible to generate a `JointEstimand` containing all linearly independent AIEs from a set of treatment values or from a dataset. For that purpose, use the `factorialEstimand` function.

As a result of Julia's automatic differentiation facilities, given a set of predefined estimands ``(\Psi_1, ..., \Psi_k)``, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. This is done via the `ComposedEstimand` type.
## Joint And Composed Estimands

For example, the difference in ATE for a treatment with 3 levels (0, 1, 2) can be defined as follows:
A `JointEstimand` is simply a list of one dimensional estimands that are grouped together. For instance for a treatment `T` taking three possible values ``(0, 1, 2)`` we can define the two following Average Treatment Effects and a corresponding `JointEstimand`:

```julia
ATE₁ = ATE(
Expand All @@ -201,5 +195,23 @@ ATE₂ = ATE(
treatment_values = (T = (control = 1, case = 2),),
treatment_confounders = [:W]
)
ATEdiff = ComposedEstimand(-, (ATE₁, ATE₂))
joint_estimand = JointEstimand(ATE₁, ATE₂)
```

You can easily generate joint estimands corresponding to Counterfactual Means, Average Treatment Effects or Average Interaction Effects by using the `factorialEstimand` function.

To estimate a joint estimand you can use any of the estimators defined in this package exactly as you would do it for a one dimensional estimand.

There are two main use cases for them that we now describe.

### Joint Testing

In some cases, like in factorial analyses where multiple versions of a treatment are tested, it may be of interest to know if any version of the versions has had an effect. This can be done via a Hotelling's T2 Test, which is simply a multivariate generalisation of the Student's T test. This is the default returned by the `significance_test` function provided in TMLE.jl and the result of the test is also printed to the REPL for any joint estimate.

### Composition

Once you have estimated a `JointEstimand` and have a `JointEstimate`, you may be interested to ask further questions. For instance whether two treatment versions have the same effect. This question is typically answered by testing if the difference in Average Treatment Effect is 0. Using the Delta Method and Julia's automatic differentiation, you don't need to explicitly define a semi-parametric estimator for it. You can simply call `compose`:

```julia
ATEdiff = compose(-, joint_estimate)
```
65 changes: 38 additions & 27 deletions docs/src/user_guide/estimation.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ Again, required nuisance functions are fitted and stored in the cache.

## Specifying Models

By default, TMLE.jl uses generalized linear models for the estimation of relevant and nuisance factors such as the outcome mean and the propensity score. However, this is not the recommended usage since the estimators' performance is closely related to how well we can estimate these factors. More sophisticated models can be provided using the `models` keyword argument of each estimator which is essentially a `NamedTuple` mapping variables' names to their respective model.
By default, TMLE.jl uses generalized linear models for the estimation of relevant and nuisance factors such as the outcome mean and the propensity score. However, this is not the recommended usage since the estimators' performance is closely related to how well we can estimate these factors. More sophisticated models can be provided using the `models` keyword argument of each estimator which is a `Dict{Symbol, Model}` mapping variables' names to their respective model.

Rather than specifying a specific model for each variable it may be easier to override the default models using the `default_models` function:

Expand All @@ -121,9 +121,9 @@ using MLJXGBoostInterface
xgboost_regressor = XGBoostRegressor()
xgboost_classifier = XGBoostClassifier()
models = default_models(
Q_binary=xgboost_classifier,
Q_continuous=xgboost_regressor,
G=xgboost_classifier
Q_binary = xgboost_classifier,
Q_continuous = xgboost_regressor,
G = xgboost_classifier
)
tmle_gboost = TMLEE(models=models)
```
Expand All @@ -140,19 +140,18 @@ stack_binary = Stack(
lr=lr
)
models = (
T₁ = with_encoder(xgboost_classifier), # T₁ with XGBoost prepended with a Continuous Encoder
default_models( # For all other variables use the following defaults
Q_binary=stack_binary, # A Super Learner
Q_continuous=xgboost_regressor, # An XGBoost
models = default_models( # For all non-specified variables use the following defaults
Q_binary = stack_binary, # A Super Learner
Q_continuous = xgboost_regressor, # An XGBoost
# T₁ with XGBoost prepended with a Continuous Encoder
T₁ = xgboost_classifier
# Unspecified G defaults to Logistic Regression
)...
)
tmle_custom = TMLEE(models=models)
```

Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `NamedTuple`.
Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `Dict`.

## CV-Estimation

Expand Down Expand Up @@ -196,10 +195,10 @@ result₃
nothing # hide
```

This time only the model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `IATE` between `T₁` and `T₂`.
This time only the model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `AIE` between `T₁` and `T₂`.

```@example estimation
Ψ₄ = IATE(
Ψ₄ = AIE(
outcome=:Y,
treatment_values=(
T₁=(case=true, control=false),
Expand All @@ -218,18 +217,20 @@ nothing # hide

All nuisance functions have been reused, only the fluctuation is fitted!

## Composing Estimands
## Joint Estimands and Composition

By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can estimate any estimand which is a function of already estimated estimands. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system.
As explained in [Joint And Composed Estimands](@ref), a joint estimand is simply a collection of estimands. Here, we will illustrate that an Average Interaction Effect is also defined as a difference in partial Average Treatment Effects.

For instance, by definition of the ``IATE``, we should be able to retrieve:
More precisely, we would like to see if the left-hand side of this equation is equal to the right-hand side:

```math
IATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} = ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} - ATE_{T_1=0, T_2=0 \rightarrow 1} - ATE_{T_1=0 \rightarrow 1, T_2=0}
AIE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} = ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} - ATE_{T_1=0, T_2=0 \rightarrow 1} - ATE_{T_1=0 \rightarrow 1, T_2=0}
```

For that, we need to define a joint estimand of three components:

```@example estimation
first_ate = ATE(
ATE₁ = ATE(
outcome=:Y,
treatment_values=(
T₁=(case=true, control=false),
Expand All @@ -239,9 +240,7 @@ first_ate = ATE(
T₂=[:W₂₁, :W₂₂],
),
)
first_ate_result, cache = tmle(first_ate, dataset, cache=cache, verbosity=0);
second_ate = ATE(
ATE₂ = ATE(
outcome=:Y,
treatment_values=(
T₁=(case=false, control=false),
Expand All @@ -251,15 +250,27 @@ second_ate = ATE(
T₂=[:W₂₁, :W₂₂],
),
)
second_ate_result, cache = tmle(second_ate, dataset, cache=cache, verbosity=0);
joint_estimand = JointEstimand(Ψ₃, ATE₁, ATE₂)
```

composed_iate_result = compose(
(x, y, z) -> x - y - z,
result₃, first_ate_result, second_ate_result
)
where the interaction `Ψ₃` was defined earlier. This joint estimand can be estimated like any other estimand using our estimator of choice:

```@example estimation
joint_estimate, cache = tmle(joint_estimand, dataset, cache=cache, verbosity=0);
joint_estimate
```

The printed output is the result of a Hotelling's T2 Test which is the multivariate counterpart of the Student's T Test. It tells us whether any of the component of this joint estimand is different from 0.

Then we can formally test our hypothesis by leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation.

```@example estimation
composed_result = compose((x, y, z) -> x - y - z, joint_estimate)
isapprox(
estimate(result₄),
estimate(composed_iate_result),
first(estimate(composed_result)),
atol=0.1
)
```

By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system.
10 changes: 6 additions & 4 deletions docs/src/walk_through.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ marginal_ate_t1 = ATE(
)
```

- The Interaction Average Treatment Effect:
- The Average Interaction Effect:

```@example walk-through
iate = IATE(
aie = AIE(
outcome = :Y,
treatment_values = (
T₁=(case=1, control=0),
Expand All @@ -125,7 +125,7 @@ iate = IATE(
Identification is the process by which a Causal Estimand is turned into a Statistical Estimand, that is, a quantity we may estimate from data. This is done via the `identify` function which also takes in the ``SCM``:

```@example walk-through
statistical_iate = identify(iate, scm)
statistical_aie = identify(aie, scm)
```

Alternatively, you can also directly define the statistical parameters (see [Estimands](@ref)).
Expand All @@ -149,7 +149,7 @@ Statistical Estimands can be estimated without a ``SCM``, let's use the One-Step

```@example walk-through
ose = OSE()
result, cache = ose(statistical_iate, dataset)
result, cache = ose(statistical_aie, dataset)
result
```

Expand All @@ -160,3 +160,5 @@ Both TMLE and OSE asymptotically follow a Normal distribution. It means we can p
```@example walk-through
OneSampleTTest(result)
```

If the estimate is high-dimensional, a `OneSampleHotellingT2Test` should be performed instead. Alternatively, the `significance_test` function will automatically select the appropriate test for the estimate and return its result.
6 changes: 3 additions & 3 deletions examples/double_robustness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ function tmle_inference(data)
treatment_values=(Tcat=(case=1.0, control=0.0),),
treatment_confounders=(Tcat=[:W],)
)
models = (
Y = with_encoder(LinearRegressor()),
Tcat = with_encoder(LinearBinaryClassifier())
models = Dict(
:Y => with_encoder(LinearRegressor()),
:Tcat => with_encoder(LinearBinaryClassifier())
)
tmle = TMLEE(models=models)
result, _ = tmle(Ψ, data; verbosity=0)
Expand Down
10 changes: 6 additions & 4 deletions src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@ using Graphs
using MetaGraphsNext
using Combinatorics
using SplitApplyCombine
using OrderedCollections
using AutoHashEquals

# #############################################################################
# EXPORTS
# #############################################################################

export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices
export CM, ATE, IATE
export CM, ATE, AIE
export AVAILABLE_ESTIMANDS
export factorialEstimand, factorialEstimands
export TMLEE, OSE, NAIVE
export ComposedEstimand
export JointEstimand, ComposedEstimand
export var, estimate, pvalue, confint, emptyIC
export significance_test, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test
export compose
Expand All @@ -48,8 +50,8 @@ include("utils.jl")
include("scm.jl")
include("adjustment.jl")
include("estimands.jl")
include("estimators.jl")
include("estimates.jl")
include("estimators.jl")
include("treatment_transformer.jl")
include("estimand_ordering.jl")

Expand All @@ -61,6 +63,6 @@ include("counterfactual_mean_based/clever_covariate.jl")
include("counterfactual_mean_based/gradient.jl")

include("configuration.jl")

include("testing.jl")

end
4 changes: 2 additions & 2 deletions src/configuration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ from_dict!(x) = x
from_dict!(v::AbstractVector) = [from_dict!(x) for x in v]

"""
from_dict!(d::Dict)
from_dict!(d::AbstractDict)
Converts a dictionary to a TMLE struct.
"""
function from_dict!(d::Dict{T, Any}) where T
function from_dict!(d::AbstractDict{T, Any}) where T
haskey(d, T(:type)) || return Dict(key => from_dict!(val) for (key, val) in d)
constructor = eval(Meta.parse(pop!(d, :type)))
return constructor(;(key => from_dict!(val) for (key, val) in d)...)
Expand Down
Loading

2 comments on commit 248bc05

@olivierlabayle
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/113576

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.17.0 -m "<description of version>" 248bc0522f51c14941f2191e8d0619cf3d2ece1b
git push origin v0.17.0

Please sign in to comment.