Skip to content

Commit

Permalink
Merge pull request #103 from TARGENE/significance_test
Browse files Browse the repository at this point in the history
Significance test
  • Loading branch information
olivierlabayle authored Jan 31, 2024
2 parents c1e044c + 1f6db9e commit 58c2a3a
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 32 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.14.0"
version = "0.14.1"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Expand All @@ -19,7 +19,6 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -51,12 +50,11 @@ MLJModels = "0.15, 0.16"
MetaGraphsNext = "0.7"
Missings = "1.0"
PrecompileTools = "1.1.1"
PrettyTables = "2.2"
SplitApplyCombine = "1.2.2"
TableOperations = "1.2"
Tables = "1.6"
YAML = "0.4.9"
Zygote = "0.6.69"
SplitApplyCombine = "1.2.2"
julia = "1.6, 1.7, 1"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using Distributions
using Zygote
using LogExpFunctions
using PrecompileTools
using PrettyTables
using Random
import AbstractDifferentiation as AD
using Graphs
Expand All @@ -32,7 +31,8 @@ export AVAILABLE_ESTIMANDS
export factorialATE, factorialIATE
export TMLEE, OSE, NAIVE
export ComposedEstimand
export var, estimate, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test,pvalue, confint, emptyIC
export var, estimate, pvalue, confint, emptyIC
export significance_test, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test
export compose
export TreatmentTransformer, with_encoder, encoder
export BackdoorAdjustment, identify
Expand Down
16 changes: 9 additions & 7 deletions src/counterfactual_mean_based/estimates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,6 @@ end

emptyIC(estimate; pval_threshold=nothing) = emptyIC(estimate, pval_threshold)


function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate)
testresult = OneSampleTTest(est)
data = [estimate(est) confint(testresult) pvalue(testresult);]
pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"])
end

"""
Distributions.estimate(r::EICEstimate)
Expand Down Expand Up @@ -104,3 +97,12 @@ Performs a T test on the EICEstimate.
HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) =
OneSampleTTest(est.estimate, est.std, est.n, Ψ₀)

"""
significance_test(estimate::EICEstimate, Ψ₀=0)
Performs a TTest
"""
significance_test(estimate::EICEstimate, Ψ₀=0) = OneSampleTTest(estimate, Ψ₀)

Base.show(io::IO, mime::MIME"text/plain", est::Union{EICEstimate, ComposedEstimand}) =
show(io, mime, significance_test(est))
26 changes: 14 additions & 12 deletions src/estimates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,6 @@ to_matrix(x) = reduce(hcat, x)
ComposedEstimate(;estimand, estimates, estimate, cov, n) =
ComposedEstimate(estimand, Tuple(estimates), collect(estimate), to_matrix(cov), n)


function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate)
if length(est.cov) !== 1
println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est)))
else
testresult = OneSampleTTest(est)
data = [estimate(est) confint(testresult) pvalue(testresult);]
headers = ["Estimate", "95% Confidence Interval", "P-value"]
pretty_table(io, data;header=headers)
end
end

"""
Distributions.estimate(r::ComposedEstimate)
Expand Down Expand Up @@ -171,6 +159,20 @@ function HypothesisTests.OneSampleZTest(estimate::ComposedEstimate, Ψ₀=0)
return OneSampleZTest(estimate.estimate[1], sqrt(estimate.cov[1]), estimate.n, Ψ₀)
end

"""
significance_test(estimate::ComposedEstimate, Ψ₀=zeros(size(estimate.estimate, 1)))
Performs a TTest if the estimate is one dimensional and a HotellingT2Test otherwise.
"""
function significance_test(estimate::ComposedEstimate, Ψ₀=zeros(size(estimate.estimate, 1)))
if length(estimate.estimate) == 1
Ψ₀ = Ψ₀ isa AbstractArray ? first(Ψ₀) : Ψ₀
return OneSampleTTest(estimate, Ψ₀)
else
return OneSampleHotellingT2Test(estimate, Ψ₀)
end
end

function emptyIC(estimate::ComposedEstimate, pval_threshold)
emptied_estimates = Tuple(emptyIC(e, pval_threshold) for e in estimate.estimates)
ComposedEstimate(estimate.estimand, emptied_estimates, estimate.estimate, estimate.cov, estimate.n)
Expand Down
4 changes: 2 additions & 2 deletions test/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ end
ose = OSE(models=TMLE.default_models(G=LogisticClassifier(), Q_continuous=LinearRegressor()))
jointEstimate, _ = ose(jointIATE, dataset, verbosity=0)

testres = OneSampleHotellingT2Test(jointEstimate)
testres = significance_test(jointEstimate)
@test testres. jointEstimate.estimate
@test pvalue(testres) < 1e-10

Expand All @@ -213,7 +213,7 @@ end
maybe_emptied_estimate = TMLE.emptyIC(jointEstimate, pval_threshold=pval_threshold)
n_empty = 0
for i in 1:3
pval = pvalue(OneSampleTTest(jointEstimate.estimates[i]))
pval = pvalue(significance_test(jointEstimate.estimates[i]))
maybe_emptied_IC = maybe_emptied_estimate.estimates[i].IC
if pval > pval_threshold
@test maybe_emptied_IC == []
Expand Down
2 changes: 1 addition & 1 deletion test/counterfactual_mean_based/non_regression_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using YAML

function regression_tests(tmle_result)
@test estimate(tmle_result) -0.185533 atol = 1e-6
l, u = confint(OneSampleTTest(tmle_result))
l, u = confint(significance_test(tmle_result))
@test l -0.279246 atol = 1e-6
@test u -0.091821 atol = 1e-6
@test OneSampleZTest(tmle_result) isa OneSampleZTest
Expand Down
5 changes: 1 addition & 4 deletions test/helper_fns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ at the given confidence level: here 0.05
"""
function test_coverage(result::TMLE.EICEstimate, Ψ₀)
# TMLE
lb, ub = confint(OneSampleTTest(result))
@test lb Ψ₀ ub
# OneStep
lb, ub = confint(OneSampleZTest(result))
lb, ub = confint(significance_test(result))
@test lb Ψ₀ ub
end

Expand Down

2 comments on commit 58c2a3a

@olivierlabayle
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Add significance_test function
  • Change show methods of Estimates to return test result

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/99939

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.1 -m "<description of version>" 58c2a3a292398e3dd8ff64c3f4d0df2034542c27
git push origin v0.14.1

Please sign in to comment.