diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 7bdde6f..e6584d7 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -227,10 +227,20 @@ function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) wher ) end -unique_non_missing(dataset, colname) = unique(skipmissing(Tables.getcolumn(dataset, colname))) +function get_treatment_values(dataset, colname) + counts = groupcount(skipmissing(Tables.getcolumn(dataset, colname))) + sorted_counts = sort(collect(pairs(counts)), by = x -> x.second, rev=true) + return first.(sorted_counts) +end -unique_treatment_values(dataset, colnames) =(;(colname => unique_non_missing(dataset, colname) for colname in colnames)...) +""" + unique_treatment_values(dataset, colnames) +We ensure that the values are sorted by frequency to maximize +the number of estimands passing the positivity constraint. +""" +unique_treatment_values(dataset, colnames) = + (;(colname => get_treatment_values(dataset, colname) for colname in colnames)...) get_transitive_treatments_contrasts(treatments_unique_values) = [collect(zip(vals[1:end-1], vals[2:end])) for vals in values(treatments_unique_values)] diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 95b9677..9d74013 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -2,7 +2,6 @@ module TestEstimands using Test using TMLE - @testset "Test StatisticalCMCompositeEstimand" begin dataset = ( W = [1, 2, 3, 4, 5, 6, 7, 8], @@ -208,6 +207,19 @@ end treatments_unique_values = (T₁=(1, 0, 2), T₂=["AC", "CC"]) @test TMLE.get_transitive_treatments_contrasts(treatments_unique_values) == [[(1, 0), (0, 2)], [("AC", "CC")]] end + +@testset "Test unique_treatment_values" begin + dataset = ( + T₁ = ["AC", missing, "AC", "CC", "CC", "AA", "CC"], + T₂ = [1, missing, 1, 2, 2, 3, 2] + ) + # most frequent to least frequent + @test TMLE.unique_treatment_values(dataset, (:T₁, :T₂)) == ( + T₁ = ["CC", "AC", "AA"], + T₂ = [2, 1, 3], + ) +end + @testset "Test factorialATE" begin dataset = ( T₁ = [0, 1, 2, missing],