diff --git a/src/tmle_inputs/allele_independent_estimands.jl b/src/tmle_inputs/allele_independent_estimands.jl
index e962872..bfabb32 100644
--- a/src/tmle_inputs/allele_independent_estimands.jl
+++ b/src/tmle_inputs/allele_independent_estimands.jl
@@ -19,6 +19,64 @@ function save_batch!(batch_saver::BatchManager, groupname)
     batch_saver.current_batch_size = 0
 end
 
+"""
+    generate_treatments_combinations(treatments_lists, orders)
+
+Generate treatment combinations for each order in `orders`. The final list is sorted
+to encourage reuse of nuisance functions in downstream estimation.
+"""
+function generate_treatments_combinations(treatments_lists, orders)
+    treatment_combinations = []
+    for order in orders
+        for treatments_lists_at_order in Combinatorics.combinations(treatments_lists, order)
+            for treatment_comb in Iterators.product(treatments_lists_at_order...)
+                push!(treatment_combinations, treatment_comb)
+            end
+        end
+    end
+    return sort(treatment_combinations)
+end
+
+"""
+    generate_interactions!(batch_saver, dataset, variants_config, outcomes, confounders; kwargs...)
+
+For each group in `variants_config`, build the treatment combinations at the requested
+`orders`, generate the IATEs for every outcome and push the non-empty ones to `batch_saver`.
+A batch is saved when `max_batch_size` (if set) is exceeded and at the end of each group.
+"""
+function generate_interactions!(batch_saver, dataset, variants_config, outcomes, confounders;
+    extra_treatments=[],
+    outcome_extra_covariates=[],
+    positivity_constraint=0.,
+    orders=[2]
+    )
+    for (groupname, variants_dict) ∈ variants_config
+        treatments_lists = [Symbol.(variant_list) for variant_list in values(variants_dict)]
+        isempty(extra_treatments) || push!(treatments_lists, extra_treatments)
+        for treatments ∈ generate_treatments_combinations(treatments_lists, orders)
+            for outcome in outcomes
+                Ψ = generateIATEs(dataset, treatments, outcome,
+                    confounders = confounders,
+                    outcome_extra_covariates=outcome_extra_covariates,
+                    positivity_constraint=positivity_constraint,
+                )
+                ncomponents = length(Ψ.args)
+                if ncomponents > 0
+                    push!(batch_saver.current_estimands, Ψ)
+                    batch_saver.current_batch_size += ncomponents + 1
+                end
+                if batch_saver.max_batch_size !== nothing && batch_saver.current_batch_size > batch_saver.max_batch_size
+                    save_batch!(batch_saver, groupname)
+                end
+            end
+        end
+        # Save at the end of a group
+        if batch_saver.current_batch_size > 0
+            save_batch!(batch_saver, groupname)
+        end
+    end
+end
+
 function allele_independent_estimands(parsed_args)
     outprefix = parsed_args["out-prefix"]
     batch_saver = BatchManager(outprefix, parsed_args["batch-size"])
@@ -30,7 +88,7 @@ function allele_independent_estimands(parsed_args)
     config = YAML.load_file(parsed_args["allele-independent"]["config"])
 
     # Variables
-    variants = config["variants"]
+    variants_config = config["variants"]
     extra_treatments = haskey(config, "extra_treatments") ? Symbol.(config["extra_treatments"]) : []
     outcome_extra_covariates = haskey(config, "outcome_extra_covariates") ? Symbol.(config["outcome_extra_covariates"]) : []
     extra_confounders = haskey(config, "extra_confounders") ? Symbol.(config["extra_confounders"]) : []
@@ -39,34 +97,23 @@ function allele_independent_estimands(parsed_args)
     outcomes = filter(x -> x ∉ nonoutcomes, Symbol.(names(traits)))
 
     # Genotypes and final dataset
-    variants_set = Set(TargeneCore.retrieve_variants_list(variants))
+    variants_set = Set(TargeneCore.retrieve_variants_list(variants_config))
     genotypes = TargeneCore.call_genotypes(bgen_prefix, variants_set, call_threshold)
     dataset = TargeneCore.merge(traits, pcs, genotypes)
     Arrow.write(string(outprefix, ".data.arrow"), dataset)
 
     # Estimands
-    for (groupname, variants_dict) ∈ variants
-        for prod ∈ Iterators.product(values(variants_dict)...)
-            treatments = vcat(collect(Symbol.(prod)), extra_treatments)
-            for outcome in outcomes
-                Ψ = generateIATEs(dataset, treatments, outcome,
-                    confounders = confounders,
-                    outcome_extra_covariates=outcome_extra_covariates,
-                    positivity_constraint=positivity_constraint,
-                )
-                ncomponents = length(Ψ.args)
-                if ncomponents > 0
-                    push!(batch_saver.current_estimands, Ψ)
-                    batch_saver.current_batch_size += ncomponents + 1
-                end
-                if batch_saver.max_batch_size !== nothing && batch_saver.current_batch_size > batch_saver.max_batch_size
-                    save_batch!(batch_saver, groupname)
-                end
-            end
-        end
-        # Save at the end of a group
-        if batch_saver.current_batch_size > 0
-            save_batch!(batch_saver, groupname)
+    for estimand_type in config["estimands"]
+        if estimand_type == "interactions"
+            orders = config["orders"]
+            generate_interactions!(batch_saver, dataset, variants_config, outcomes, confounders;
+                extra_treatments=extra_treatments,
+                outcome_extra_covariates=outcome_extra_covariates,
+                positivity_constraint=positivity_constraint,
+                orders=orders
+            )
+        else
+            throw(ArgumentError(string("Unknown estimand type: ", estimand_type)))
         end
     end
 
diff --git a/test/data/interaction_config.yaml b/test/data/interaction_config.yaml
index f367ea4..c735346 100644
--- a/test/data/interaction_config.yaml
+++ b/test/data/interaction_config.yaml
@@ -1,5 +1,6 @@
-orders: [2]
-estimands: interactions
+orders: [2, 3]
+estimands:
+  - interactions
 variants:
   TF1:
     bQTLs:
diff --git a/test/tmle_inputs/allele_independent_estimands.jl b/test/tmle_inputs/allele_independent_estimands.jl
index bbd258d..25de8cf 100644
--- a/test/tmle_inputs/allele_independent_estimands.jl
+++ b/test/tmle_inputs/allele_independent_estimands.jl
@@ -10,6 +10,34 @@ TESTDIR = joinpath(pkgdir(TargeneCore), "test")
 
 include(joinpath(TESTDIR, "tmle_inputs", "test_utils.jl"))
 
+@testset "Test generate_treatments_combinations" begin
+    treatments_list = [
+        [:RSID_1, :RSID_2],
+        [:RSID_3, :RSID_4],
+        [:RSID_5],
+    ]
+    order_2 = TargeneCore.generate_treatments_combinations(treatments_list, [2])
+    @test order_2 == [
+        (:RSID_1, :RSID_3),
+        (:RSID_1, :RSID_4),
+        (:RSID_1, :RSID_5),
+        (:RSID_2, :RSID_3),
+        (:RSID_2, :RSID_4),
+        (:RSID_2, :RSID_5),
+        (:RSID_3, :RSID_5),
+        (:RSID_4, :RSID_5)
+    ]
+    order_3 = TargeneCore.generate_treatments_combinations(treatments_list, [3])
+    @test order_3 == [
+        (:RSID_1, :RSID_3, :RSID_5),
+        (:RSID_1, :RSID_4, :RSID_5),
+        (:RSID_2, :RSID_3, :RSID_5),
+        (:RSID_2, :RSID_4, :RSID_5)
+    ]
+    order_2_3 = TargeneCore.generate_treatments_combinations(treatments_list, [2, 3])
+    @test order_2_3 == sort(vcat(order_2, order_3))
+end
+
 @testset "Test allele-independent: no positivity constraint" begin
     tmpdir = mktempdir()
     parsed_args = Dict(
@@ -42,15 +70,19 @@ include(joinpath(TESTDIR, "tmle_inputs", "test_utils.jl"))
             append!(tf_estimands[:TF2], deserialize(joinpath(tmpdir, file)).estimands)
         end
     end
+    unique_n_components = Set{Int}([])
     for (tf, estimands) ∈ tf_estimands
         # Number of generated estimands
-        ntraits, nbQTLs, neQTLs = 4, 2, 1
-        @test length(estimands) == ntraits*nbQTLs*neQTLs
+        n_traits = 4
+        n_treat_comb_order_2 = 5
+        n_treat_comb_order_3 = 2
+        @test length(estimands) == n_traits*(n_treat_comb_order_2+n_treat_comb_order_3)
         for Ψ ∈ estimands
-            # No positivity constraint
-            @test length(Ψ.args) == 9
+            push!(unique_n_components, length(Ψ.args))
         end
     end
+    # No positivity constraint: check the distinct numbers of components
+    @test unique_n_components == Set([3, 9])
 end
 
 @testset "Test allele-independent: with positivity constraint" begin
@@ -85,14 +117,16 @@ end
             append!(tf_estimands[:TF2], deserialize(joinpath(tmpdir, file)).estimands)
         end
     end
+    unique_n_components = Set{Int}([])
     for (tf, estimands) ∈ tf_estimands
         # Number of generated estimands
-        @test length(estimands) == 4 < 8
+        @test length(estimands) < 28
         for Ψ ∈ estimands
-            # No positivity constraint
-            @test length(Ψ.args) == 1 < 9
+            push!(unique_n_components, length(Ψ.args))
         end
     end
+    # With positivity constraint: check the distinct numbers of components
+    @test unique_n_components == Set([1, 3, 5])
 end
 
 end