Skip to content

Commit

Permalink
add more general interaction generation
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Jan 12, 2024
1 parent 25cdab7 commit feccc8a
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 33 deletions.
88 changes: 64 additions & 24 deletions src/tmle_inputs/allele_independent_estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,57 @@ function save_batch!(batch_saver::BatchManager, groupname)
batch_saver.current_batch_size = 0
end

"""
generate_treatments_combinations(treatments_lists, orders)
Generate treatment combinations for every order in `orders`. The final list is sorted
to encourage reuse of nuisance functions in downstream estimation.
"""
function generate_treatments_combinations(treatments_lists, orders)
    # Untyped accumulator: tuples of different lengths (one length per order) end up
    # in the same vector.
    combos = []
    # For each requested order, choose `order` of the treatment lists, then take one
    # treatment from each chosen list (cartesian product of the chosen lists).
    for order in orders, chosen_lists in Combinatorics.combinations(treatments_lists, order)
        append!(combos, Iterators.product(chosen_lists...))
    end
    # `combos` is a fresh local vector, so sorting it in place is equivalent to
    # returning a sorted copy.
    return sort!(combos)
end

"""
    generate_interactions!(batch_saver, dataset, variants_config, outcomes, confounders;
        extra_treatments=[],
        outcome_extra_covariates=[],
        positivity_constraint=0.,
        orders=[2]
    )

Generate interaction estimands for every group of `variants_config` and accumulate them
in `batch_saver`, saving batches along the way.

For each `(groupname, variants_dict)` entry, the values of `variants_dict` are converted
to lists of `Symbol`s; `extra_treatments`, when non-empty, is added as one more list.
Treatment tuples are then produced by `generate_treatments_combinations` for the requested
`orders`, and `generateIATEs` is called for each treatment tuple and each outcome.

# Arguments
- `batch_saver`: object whose `current_estimands`, `current_batch_size` and
  `max_batch_size` fields are read/updated here and which is passed to `save_batch!`.
- `dataset`: forwarded as-is to `generateIATEs`.
- `variants_config`: iterable of `(groupname, variants_dict)` pairs.
- `outcomes`: iterable of outcomes, one estimand is requested per outcome.
- `confounders`: forwarded to `generateIATEs`.

# Keywords
- `extra_treatments`: additional treatment list appended to each group's lists.
- `outcome_extra_covariates`: forwarded to `generateIATEs`.
- `positivity_constraint`: forwarded to `generateIATEs`.
- `orders`: interaction orders forwarded to `generate_treatments_combinations`.
"""
function generate_interactions!(batch_saver, dataset, variants_config, outcomes, confounders;
    extra_treatments=[],
    outcome_extra_covariates=[],
    positivity_constraint=0.,
    orders=[2]
    )
    for (groupname, variants_dict) in variants_config
        treatments_lists = [Symbol.(variant_list) for variant_list in values(variants_dict)]
        # Extra treatments form one additional list, so they are combined with the
        # variants like any other treatment list.
        isempty(extra_treatments) || push!(treatments_lists, extra_treatments)
        for treatments in generate_treatments_combinations(treatments_lists, orders)
            for outcome in outcomes
                Ψ = generateIATEs(dataset, treatments, outcome,
                    confounders = confounders,
                    outcome_extra_covariates=outcome_extra_covariates,
                    positivity_constraint=positivity_constraint,
                )
                ncomponents = length(Ψ.args)
                # Estimands without any component are not stored.
                if ncomponents > 0
                    push!(batch_saver.current_estimands, Ψ)
                    # NOTE(review): the extra `+ 1` presumably accounts for the composed
                    # estimand itself on top of its components -- confirm.
                    batch_saver.current_batch_size += ncomponents + 1
                end
                # A batch is only written once the size limit is strictly exceeded;
                # no limit (`nothing`) means batches are only written per group.
                if batch_saver.max_batch_size !== nothing && batch_saver.current_batch_size > batch_saver.max_batch_size
                    save_batch!(batch_saver, groupname)
                end
            end
        end
        # Save at the end of a group so that a batch never spans two groups
        if batch_saver.current_batch_size > 0
            save_batch!(batch_saver, groupname)
        end
    end
end

function allele_independent_estimands(parsed_args)
outprefix = parsed_args["out-prefix"]
batch_saver = BatchManager(outprefix, parsed_args["batch-size"])
Expand All @@ -30,7 +81,7 @@ function allele_independent_estimands(parsed_args)
config = YAML.load_file(parsed_args["allele-independent"]["config"])

# Variables
variants = config["variants"]
variants_config = config["variants"]
extra_treatments = haskey(config, "extra_treatments") ? Symbol.(config["extra_treatments"]) : []
outcome_extra_covariates = haskey(config, "outcome_extra_covariates") ? Symbol.(config["outcome_extra_covariates"]) : []
extra_confounders = haskey(config, "extra_confounders") ? Symbol.(config["extra_confounders"]) : []
Expand All @@ -39,34 +90,23 @@ function allele_independent_estimands(parsed_args)
outcomes = filter(x -> x nonoutcomes, Symbol.(names(traits)))

# Genotypes and final dataset
variants_set = Set(TargeneCore.retrieve_variants_list(variants))
variants_set = Set(TargeneCore.retrieve_variants_list(variants_config))
genotypes = TargeneCore.call_genotypes(bgen_prefix, variants_set, call_threshold)
dataset = TargeneCore.merge(traits, pcs, genotypes)
Arrow.write(string(outprefix, ".data.arrow"), dataset)

# Estimands
for (groupname, variants_dict) variants
for prod Iterators.product(values(variants_dict)...)
treatments = vcat(collect(Symbol.(prod)), extra_treatments)
for outcome in outcomes
Ψ = generateIATEs(dataset, treatments, outcome,
confounders = confounders,
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
)
ncomponents = length.args)
if ncomponents > 0
push!(batch_saver.current_estimands, Ψ)
batch_saver.current_batch_size += ncomponents + 1
end
if batch_saver.max_batch_size !== nothing && batch_saver.current_batch_size > batch_saver.max_batch_size
save_batch!(batch_saver, groupname)
end
end
end
# Save at the end of a group
if batch_saver.current_batch_size > 0
save_batch!(batch_saver, groupname)
for estimand_type in config["estimands"]
if estimand_type == "interactions"
orders = config["orders"]
generate_interactions!(batch_saver, dataset, variants_config, outcomes, confounders;
extra_treatments=extra_treatments,
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
orders=orders
)
else
throw(ArgumentError(string("Unknown estimand type: ", estimand_type)))
end
end

Expand Down
5 changes: 3 additions & 2 deletions test/data/interaction_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
orders: [2]
estimands: interactions
orders: [2, 3]
estimands:
- interactions
variants:
TF1:
bQTLs:
Expand Down
48 changes: 41 additions & 7 deletions test/tmle_inputs/allele_independent_estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@ TESTDIR = joinpath(pkgdir(TargeneCore), "test")

include(joinpath(TESTDIR, "tmle_inputs", "test_utils.jl"))

@testset "Test generate_treatments_combinations" begin
    # Three treatment lists: two with two variants each and one with a single variant.
    treatments_list = [[:RSID_1, :RSID_2], [:RSID_3, :RSID_4], [:RSID_5]]

    # Order 2: one treatment from each of two distinct lists, sorted.
    expected_order_2 = [
        (:RSID_1, :RSID_3),
        (:RSID_1, :RSID_4),
        (:RSID_1, :RSID_5),
        (:RSID_2, :RSID_3),
        (:RSID_2, :RSID_4),
        (:RSID_2, :RSID_5),
        (:RSID_3, :RSID_5),
        (:RSID_4, :RSID_5),
    ]
    order_2 = TargeneCore.generate_treatments_combinations(treatments_list, [2])
    @test order_2 == expected_order_2

    # Order 3: one treatment from each of the three lists, sorted.
    expected_order_3 = [
        (:RSID_1, :RSID_3, :RSID_5),
        (:RSID_1, :RSID_4, :RSID_5),
        (:RSID_2, :RSID_3, :RSID_5),
        (:RSID_2, :RSID_4, :RSID_5),
    ]
    order_3 = TargeneCore.generate_treatments_combinations(treatments_list, [3])
    @test order_3 == expected_order_3

    # Several orders at once: the sorted union of the per-order results.
    order_2_3 = TargeneCore.generate_treatments_combinations(treatments_list, [2, 3])
    expected_order_2_3 = sort(vcat(order_2, order_3))
    @test order_2_3 == expected_order_2_3
end

@testset "Test allele-independent: no positivity constraint" begin
tmpdir = mktempdir()
parsed_args = Dict(
Expand Down Expand Up @@ -42,15 +70,19 @@ include(joinpath(TESTDIR, "tmle_inputs", "test_utils.jl"))
append!(tf_estimands[:TF2], deserialize(joinpath(tmpdir, file)).estimands)
end
end
unique_n_components = Set{Int}([])
for (tf, estimands) tf_estimands
# Number of generated estimands
ntraits, nbQTLs, neQTLs = 4, 2, 1
@test length(estimands) == ntraits*nbQTLs*neQTLs
n_traits = 4
n_treat_comb_order_2 = 5
n_treat_comb_order_3 = 2
@test length(estimands) == n_traits*(n_treat_comb_order_2+n_treat_comb_order_3)
for Ψ estimands
# No positivity constraint
@test length.args) == 9
push!(unique_n_components, length.args))
end
end
# positivity constraint
unique_n_components == Set([3, 9])
end

@testset "Test allele-independent: with positivity constraint" begin
Expand Down Expand Up @@ -85,14 +117,16 @@ end
append!(tf_estimands[:TF2], deserialize(joinpath(tmpdir, file)).estimands)
end
end
unique_n_components = Set{Int}([])
for (tf, estimands) tf_estimands
# Number of generated estimands
@test length(estimands) == 4 < 8
length(estimands) < 28
for Ψ estimands
# No positivity constraint
@test length.args) == 1 < 9
push!(unique_n_components, length.args))
end
end
# positivity constraint
unique_n_components == Set([1, 3, 5])
end

end
Expand Down

0 comments on commit feccc8a

Please sign in to comment.