diff --git a/Manifest.toml b/Manifest.toml index 4fecafe..9c00d6a 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -331,7 +331,9 @@ weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"] [[deps.CategoricalDistributions]] deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"] git-tree-sha1 = "926862f549a82d6c3a7145bc7f1adff2a91a39f0" +git-tree-sha1 = "926862f549a82d6c3a7145bc7f1adff2a91a39f0" uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" +version = "0.1.15" version = "0.1.15" [deps.CategoricalDistributions.extensions] @@ -492,8 +494,10 @@ version = "0.17.6" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" +git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" version = "1.5.5" +version = "1.5.5" weakdeps = ["IntervalSets", "StaticArrays"] [deps.ConstructionBase.extensions] @@ -508,8 +512,10 @@ version = "0.1.3" [[deps.Contour]] git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" +git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.3" +version = "0.6.3" [[deps.CpuId]] deps = ["Markdown"] @@ -656,7 +662,9 @@ version = "1.0.4" [[deps.EvoTrees]] deps = ["BSON", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] git-tree-sha1 = "92d1f78f95f4794bf29bd972dacfa37ea1fec9f4" +git-tree-sha1 = "92d1f78f95f4794bf29bd972dacfa37ea1fec9f4" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" +version = "0.16.7" version = "0.16.7" [deps.EvoTrees.extensions] @@ -731,8 +739,10 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" version = "1.16.3" +version = "1.16.3" [[deps.FilePaths]] deps = ["FilePathsBase", "MacroTools", "Reexport", "Requires"] @@ -791,8 +801,10 @@ version = "2.13.96+0" [[deps.Format]] git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" +git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" version = "1.3.7" +version = "1.3.7" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] @@ -912,8 +924,10 @@ version = "1.11.0" [[deps.GridLayoutBase]] deps = ["GeometryBasics", "InteractiveUtils", "Observables"] git-tree-sha1 = "6f93a83ca11346771a93bbde2bdad2f65b61498f" +git-tree-sha1 = "6f93a83ca11346771a93bbde2bdad2f65b61498f" uuid = "3955a311-db13-416c-9275-1d80ed98e5e9" version = "0.10.2" +version = "0.10.2" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -1219,7 +1233,9 @@ version = "3.100.2+0" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" +git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "6.6.3" version = "6.6.3" [deps.LLVM.extensions] @@ -1484,8 +1500,10 @@ version = "0.10.0" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7" +git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" version = "1.9.6" +version = "1.9.6" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] @@ -1751,8 +1769,10 @@ version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" +git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" version = "3.0.13+1" +version = "3.0.13+1" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1763,12 +1783,21 @@ version = "0.5.5+0" [[deps.Optim]] deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] git-tree-sha1 = "d9b79c4eed437421ac4285148fcadf42e0700e89" +deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] +git-tree-sha1 = "d9b79c4eed437421ac4285148fcadf42e0700e89" uuid = "429524aa-4258-5aef-a3af-852621145aeb" version = "1.9.4" [deps.Optim.extensions] OptimMOIExt = "MathOptInterface" + [deps.Optim.weakdeps] + MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +version = "1.9.4" + + [deps.Optim.extensions] + OptimMOIExt = "MathOptInterface" + [deps.Optim.weakdeps] MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" @@ -2283,8 +2312,10 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.34.3" +version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -2564,8 +2595,10 @@ version = "1.1.34+0" [[deps.XZ_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" +git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" version = "5.4.6+0" +version = "5.4.6+0" [[deps.Xorg_libX11_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] @@ -2629,8 +2662,10 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" +git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" version = "1.5.6+0" +version = "1.5.6+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] diff --git a/src/tl_inputs/from_actors.jl b/src/tl_inputs/from_actors.jl index 2fc8315..0d61940 100644 --- a/src/tl_inputs/from_actors.jl +++ b/src/tl_inputs/from_actors.jl @@ -124,7 +124,7 @@ function control_case_settings(::Type{TMLE.StatisticalATE}, treatments, data) end function addEstimands!(estimands, treatments, variables, data; positivity_constraint=0.) - freqs = TargeneCore.frequency_table(data, treatments) + freqs = TMLE.frequency_table(data, treatments) # This loop adds all ATE estimands where all other treatments than # the bQTL are fixed, at the order 1, this is the simple bQTL's ATE for setting in control_case_settings(TMLE.StatisticalATE, treatments, data) @@ -134,7 +134,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]), outcome_extra_covariates = variables.covariates ) - if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) + if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) update_estimands_from_outcomes!(estimands, Ψ, variables.targets) end end @@ -147,7 +147,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]), outcome_extra_covariates = variables.covariates ) - if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) + if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) update_estimands_from_outcomes!(estimands, Ψ, variables.targets) end end diff --git a/src/tl_inputs/from_param_files.jl b/src/tl_inputs/from_param_files.jl index c70ec29..2f3a79a 100644 --- a/src/tl_inputs/from_param_files.jl +++ b/src/tl_inputs/from_param_files.jl @@ -16,6 +16,7 @@ MismatchedCaseControlEncodingError() = NoRemainingParamsError(positivity_constraint) = ArgumentError(string("No parameter passed the given positivity constraint: ", positivity_constraint)) +MismatchedVariableError(variable) = ArgumentError(string("Each component of a ComposedEstimand should contain the same ", variable, " variables.")) function check_genotypes_encoding(val::NamedTuple, type) if !(typeof(val.case) <: type && typeof(val.control) <: type) @@ -27,17 +28,66 @@ check_genotypes_encoding(val::T, type) where T = T <: type || throw(MismatchedCaseControlEncodingError()) +get_treatments(Ψ) = keys(Ψ.treatment_values) + +function get_treatments(Ψ::ComposedEstimand) + treatments = get_treatments(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_treatments(arg) == treatments || throw(MismatchedVariableError("treatments")) + end + end + return treatments +end + +get_confounders(Ψ) = Tuple(Iterators.flatten((Tconf for Tconf ∈ Ψ.treatment_confounders))) + +function get_confounders(Ψ::ComposedEstimand) + confounders = get_confounders(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_confounders(arg) == confounders || throw(MismatchedVariableError("confounders")) + end + end + return confounders +end + +get_outcome_extra_covariates(Ψ) = Ψ.outcome_extra_covariates + +function get_outcome_extra_covariates(Ψ::ComposedEstimand) + outcome_extra_covariates = get_outcome_extra_covariates(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_outcome_extra_covariates(arg) == outcome_extra_covariates || throw(MismatchedVariableError("outcome extra covariates")) + end + end + return outcome_extra_covariates +end + +get_outcome(Ψ) = Ψ.outcome + +function get_outcome(Ψ::ComposedEstimand) + outcome = get_outcome(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_outcome(arg) == outcome || throw(MismatchedVariableError("outcome")) + end + end + return outcome +end + function get_variables(estimands, traits, pcs) genetic_variants = Set{Symbol}() others = Set{Symbol}() pcs = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(pcs))) alltraits = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(traits))) for Ψ in estimands - treatments = keys(Ψ.treatment_values) - confounders = Iterators.flatten((Tconf for Tconf ∈ Ψ.treatment_confounders)) + treatments = get_treatments(Ψ) + confounders = get_confounders(Ψ) + outcome_extra_covariates = get_outcome_extra_covariates(Ψ) push!( others, - Ψ.outcome_extra_covariates..., + outcome_extra_covariates..., confounders..., treatments... ) @@ -123,6 +173,8 @@ function adjust_parameter_sections(Ψ::T, variants_alleles, pcs) where T<:TMLE.E return T(outcome=Ψ.outcome, treatment_values=treatments, treatment_confounders=confounders, outcome_extra_covariates=Ψ.outcome_extra_covariates) end +adjust_parameter_sections(Ψ::ComposedEstimand, variants_alleles, pcs) = + ComposedEstimand(Ψ.f, Tuple(adjust_parameter_sections(arg, variants_alleles, pcs) for arg in Ψ.args)) function append_from_valid_estimands!( estimands::Vector{<:TMLE.Estimand}, @@ -136,29 +188,28 @@ function append_from_valid_estimands!( # Update treatment's and confounders's sections of Ψ Ψ = adjust_parameter_sections(Ψ, variants_alleles, variables.pcs) # Update frequency tables with current treatments - treatments = sorted_treatment_names(Ψ) + treatments = get_treatments(Ψ) if !haskey(frequency_tables, treatments) - frequency_tables[treatments] = TargeneCore.frequency_table(data, collect(treatments)) + frequency_tables[treatments] = TMLE.frequency_table(data, treatments) end # Check if parameter satisfies positivity - satisfies_positivity(Ψ, frequency_tables[treatments]; - positivity_constraint=positivity_constraint) || return - # Expand wildcard to all outcomes - if Ψ.outcome === :ALL - update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes) - else - # Ψ.target || MissingVariableError(variable) - push!(estimands, Ψ) + if TMLE.satisfies_positivity(Ψ, frequency_tables[treatments]; positivity_constraint=positivity_constraint) + # Expand wildcard to all outcomes + if get_outcome(Ψ) === :ALL + update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes) + else + push!(estimands, Ψ) + end end end function adjusted_estimands(estimands, variables, data; positivity_constraint=0.) final_estimands = TMLE.Estimand[] variants_alleles = Dict(v => Set(unique(skipmissing(data[!, v]))) for v in variables.genetic_variants) - freqency_tables = Dict() + frequency_tables = Dict() for Ψ in estimands # If the genotypes encoding is a string representation make sure they match the actual genotypes - append_from_valid_estimands!(final_estimands, freqency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint) + append_from_valid_estimands!(final_estimands, frequency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint) end length(final_estimands) > 0 || throw(NoRemainingParamsError(positivity_constraint)) diff --git a/src/tl_inputs/tl_inputs.jl b/src/tl_inputs/tl_inputs.jl index bed2f79..f5ddc27 100644 --- a/src/tl_inputs/tl_inputs.jl +++ b/src/tl_inputs/tl_inputs.jl @@ -64,6 +64,7 @@ NotAllVariantsFoundError(rsids) = ArgumentError(string("Some variants were not found in the genotype files: ", join(rsids, ", "))) NotBiAllelicOrUnphasedVariantError(rsid) = ArgumentError(string("Variant: ", rsid, " is not bi-allelic or not unphased.")) + """ bgen_files(snps, bgen_prefix) @@ -103,47 +104,8 @@ function call_genotypes(bgen_prefix::String, query_rsids::Set{<:AbstractString}, return genotypes end -sorted_treatment_names(Ψ) = tuple(sort(collect(keys(Ψ.treatment_values)))...) - -function setting_iterator(Ψ::TMLE.StatisticalIATE) - treatments = sorted_treatment_names(Ψ) - return ( - NamedTuple{treatments}(collect(Tval)) for - Tval in Iterators.product((values(Ψ.treatment_values[T]) for T in treatments)...) - ) -end - -function setting_iterator(Ψ::TMLE.StatisticalATE) - treatments = sorted_treatment_names(Ψ) - return ( - NamedTuple{treatments}([(Ψ.treatment_values[T][c]) for T in treatments]) - for c in (:case, :control) - ) -end - -function setting_iterator(Ψ::TMLE.StatisticalCM) - treatments = sorted_treatment_names(Ψ) - return (NamedTuple{treatments}(Ψ.treatment_values[T] for T in treatments), ) -end - -function satisfies_positivity(Ψ::TMLE.Estimand, freqs; positivity_constraint=0.01) - for base_setting in setting_iterator(Ψ) - if !haskey(freqs, base_setting) || freqs[base_setting] < positivity_constraint - return false - end - end - return true -end - -function frequency_table(data, treatments::AbstractVector) - treatments = sort(treatments) - freqs = Dict() - N = nrow(data) - for (key, group) in pairs(groupby(data, treatments; skipmissing=true)) - freqs[NamedTuple(key)] = nrow(group) / N - end - return freqs -end +TMLE.satisfies_positivity(Ψ::ComposedEstimand, freqs; positivity_constraint=0.01) = + all(TMLE.satisfies_positivity(arg, freqs; positivity_constraint=positivity_constraint) for arg in Ψ.args) read_txt_file(path::Nothing) = nothing read_txt_file(path) = CSV.read(path, DataFrame, header=false)[!, 1] @@ -164,15 +126,27 @@ function merge(traits, pcs, genotypes) ) end +estimand_with_new_outcome(Ψ::T, outcome) where T = T( + outcome=outcome, + treatment_values=Ψ.treatment_values, + treatment_confounders=Ψ.treatment_confounders, + outcome_extra_covariates=Ψ.outcome_extra_covariates +) + function update_estimands_from_outcomes!(estimands, Ψ::T, outcomes) where T for outcome in outcomes push!( - estimands, - T( - outcome=outcome, - treatment_values=Ψ.treatment_values, - treatment_confounders=Ψ.treatment_confounders, - outcome_extra_covariates=Ψ.outcome_extra_covariates) + estimands, + estimand_with_new_outcome(Ψ, outcome) + ) + end +end + +function update_estimands_from_outcomes!(estimands, Ψ::ComposedEstimand, outcomes) + for outcome in outcomes + push!( + estimands, + ComposedEstimand(Ψ.f, Tuple(estimand_with_new_outcome(arg, outcome) for arg in Ψ.args)) ) end end diff --git a/test/tl_inputs/from_param_files.jl b/test/tl_inputs/from_param_files.jl index 2215245..8e9c153 100644 --- a/test/tl_inputs/from_param_files.jl +++ b/test/tl_inputs/from_param_files.jl @@ -24,6 +24,40 @@ include(joinpath(TESTDIR, "tl_inputs", "test_utils.jl")) pcs = TargeneCore.read_csv_file(joinpath(TESTDIR, "data", "pcs.csv")) # extraW, extraT, extraC are parsed from all param_files estimands = make_estimands_configuration().estimands + # get_treatments, get_outcome, ... + ## Simple Estimand + Ψ = estimands[1] + @test TargeneCore.get_outcome(Ψ) == :ALL + @test TargeneCore.get_treatments(Ψ) == keys(Ψ.treatment_values) + @test TargeneCore.get_confounders(Ψ) == () + @test TargeneCore.get_outcome_extra_covariates(Ψ) == () + ## ComposedEstimand + Ψ = estimands[5] + @test TargeneCore.get_outcome(Ψ) == :ALL + @test TargeneCore.get_treatments(Ψ) == keys(Ψ.args[1].treatment_values) + @test TargeneCore.get_confounders(Ψ) == () + @test TargeneCore.get_outcome_extra_covariates(Ψ) == (Symbol("22001"), ) + ## Bad ComposedEstimand + Ψ = ComposedEstimand( + TMLE.joint_estimand, ( + CM( + outcome = "Y1", + treatment_values = (RSID_3 = "GG", RSID_198 = "AG"), + treatment_confounders = (RSID_3 = [], RSID_198 = []), + outcome_extra_covariates = [22001] + ), + CM( + outcome = "Y2", + treatment_values = (RSID_2 = "AA", RSID_198 = "AG"), + treatment_confounders = (RSID_2 = [:PC1], RSID_198 = []), + outcome_extra_covariates = [] + )) + ) + @test_throws ArgumentError TargeneCore.get_outcome(Ψ) == :ALL + @test_throws ArgumentError TargeneCore.get_treatments(Ψ) + @test_throws ArgumentError TargeneCore.get_confounders(Ψ) + @test_throws ArgumentError TargeneCore.get_outcome_extra_covariates(Ψ) + # get_variables variables = TargeneCore.get_variables(estimands, traits, pcs) @test variables.genetic_variants == Set([:RSID_198, :RSID_2]) @test variables.outcomes == Set([:BINARY_1, :CONTINUOUS_2, :CONTINUOUS_1, :BINARY_2]) @@ -38,8 +72,9 @@ end ) pcs = Set([:PC1, :PC2]) variants_alleles = Dict(:RSID_198 => Set(genotypes.RSID_198)) - # AG is not in the genotypes but GA is - Ψ = make_estimands_configuration().estimands[4] + estimands = make_estimands_configuration().estimands + # RS198 AG is not in the genotypes but GA is + Ψ = estimands[4] @test Ψ.treatment_values.RSID_198 == (case="AG", control="AA") new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs) @test new_Ψ.outcome == Ψ.outcome @@ -50,6 +85,19 @@ end RSID_2 = (case = "AA", control = "GG") ) + # ComnposedEstimand + Ψ = estimands[5] + @test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG") + @test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA") + new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs) + for index in 1:length(Ψ.args) + @test new_Ψ.args[index].outcome == Ψ.args[index].outcome + @test new_Ψ.args[index].outcome_extra_covariates == (Symbol(22001),) + @test new_Ψ.args[index].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2),) + end + @test new_Ψ.args[1].treatment_values == (RSID_198 = "GA", RSID_2 = "GG") + @test new_Ψ.args[2].treatment_values == (RSID_198 = "GA", RSID_2 = "AA") + # If the allele is not present variants_alleles = Dict(:RSID_198 => Set(["AA"])) @test_throws TargeneCore.AbsentAlleleError("RSID_198", "AG") TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs) @@ -95,8 +143,8 @@ end ## Estimands file: output_estimands = deserialize("final.estimands.jls").estimands - # There are 5 initial estimands containing a * - # Those are duplicated for each of the 4 targets. + # There are 5 initial estimands containing a :ALL + # Those are duplicated for each of the 4 outcomes. @test length(output_estimands) == 20 # In all cases the PCs are appended to the confounders. for Ψ ∈ output_estimands @@ -120,10 +168,11 @@ end @test Ψ.outcome_extra_covariates == (Symbol("22001"),) # Input Estimand 5: GA is corrected to AG to match the data - elseif Ψ isa TMLE.StatisticalCM && Ψ.treatment_values == (RSID_198 = "AG", RSID_2 = "GG") - @test Ψ.treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2)) - @test Ψ.outcome_extra_covariates == (Symbol("22001"),) - + elseif Ψ isa TMLE.ComposedEstimand + @test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG") + @test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA") + @test Ψ.args[1].treatment_confounders == Ψ.args[2].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2)) + @test Ψ.args[1].outcome_extra_covariates == Ψ.args[2].outcome_extra_covariates == (Symbol("22001"),) else throw(AssertionError(string("Which input did this output come from: ", Ψ))) end @@ -142,7 +191,7 @@ end tl_inputs(parsed_args) # The IATES are the most sensitives outestimands = deserialize("final.estimands.jls").estimands - @test all(Ψ isa Union{TMLE.StatisticalCM, TMLE.StatisticalATE} for Ψ in outestimands) + @test all(Ψ isa Union{TMLE.StatisticalCM, TMLE.StatisticalATE, ComposedEstimand} for Ψ in outestimands) @test size(outestimands, 1) == 16 cleanup() diff --git a/test/tl_inputs/test_utils.jl b/test/tl_inputs/test_utils.jl index 038e709..8783cda 100644 --- a/test/tl_inputs/test_utils.jl +++ b/test/tl_inputs/test_utils.jl @@ -6,7 +6,6 @@ function cleanup(;prefix="final.") end end - function make_estimands_configuration() estimands = [ IATE( @@ -32,11 +31,20 @@ function make_estimands_configuration() treatment_confounders = (RSID_2 = [], RSID_198 = []), outcome_extra_covariates = [22001] ), - CM( - outcome = "ALL", - treatment_values = (RSID_2 = "GG", RSID_198 = "GA"), - treatment_confounders = (RSID_2 = [], RSID_198 = []), - outcome_extra_covariates = [22001] + ComposedEstimand( + TMLE.joint_estimand, ( + CM( + outcome = "ALL", + treatment_values = (RSID_2 = "GG", RSID_198 = "AG"), + treatment_confounders = (RSID_2 = [], RSID_198 = []), + outcome_extra_covariates = [22001] + ), + CM( + outcome = "ALL", + treatment_values = (RSID_2 = "AA", RSID_198 = "AG"), + treatment_confounders = (RSID_2 = [], RSID_198 = []), + outcome_extra_covariates = [22001] + )) ) ] return Configuration(estimands=estimands) diff --git a/test/tl_inputs/tl_inputs.jl b/test/tl_inputs/tl_inputs.jl index efb9541..4ded885 100644 --- a/test/tl_inputs/tl_inputs.jl +++ b/test/tl_inputs/tl_inputs.jl @@ -83,103 +83,6 @@ end @test_throws TargeneCore.NotAllVariantsFoundError(variants) TargeneCore.call_genotypes(bgen_dir, variants, 0.95;) end - -@testset "Test positivity_constraint" begin - data = DataFrame( - A = [1, 1, 0, 1, 0, 2, 2, 1], - B = ["AC", "CC", "AA", "AA", "AA", "AA", "AA", "AA"] - ) - ## One variable - freqs = TargeneCore.frequency_table(data, [:A]) - @test freqs == Dict( - (A = 0,) => 0.25, - (A = 2,) => 0.25, - (A = 1,) => 0.5 - ) - Ψ = CM( - outcome = :toto, - treatment_values = (A=1,), - treatment_confounders = (A=[],) - ) - @test TargeneCore.setting_iterator(Ψ) == ((A = 1,),) - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.4) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.6) == false - - Ψ = ATE( - outcome = :toto, - treatment_values= (A= (case=1, control=0),), - treatment_confounders = (A=[],) - ) - @test collect(TargeneCore.setting_iterator(Ψ)) == [(A = 1,), (A = 0,)] - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.2) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.3) == false - - ## Two variables - # Treatments are sorted: [:B, :A] -> [:A, :B] - freqs = TargeneCore.frequency_table(data, [:B, :A]) - @test freqs == Dict( - (A = 1, B = "CC") => 0.125, - (A = 1, B = "AA") => 0.25, - (A = 0, B = "AA") => 0.25, - (A = 1, B = "AC") => 0.125, - (A = 2, B = "AA") => 0.25 - ) - - Ψ = CM( - outcome = :toto, - treatment_values = (B = "CC", A = 1), - treatment_confounders = (B = [], A = []) - ) - @test TargeneCore.setting_iterator(Ψ) == ((A = 1, B = "CC"),) - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.1) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.15) == false - - Ψ = ATE( - outcome = :toto, - treatment_values = (B=(case="AA", control="AC"), A=(case=1, control=1),), - treatment_confounders = (B = (), A = (),) - ) - @test collect(TargeneCore.setting_iterator(Ψ)) == [(A = 1, B = "AA"), (A = 1, B = "AC")] - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.1) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.2) == false - - Ψ = IATE( - outcome = :toto, - treatment_values = (B=(case="AC", control="AA"), A=(case=1, control=0),), - treatment_confounders = (B=(), A=()), - ) - @test collect(TargeneCore.setting_iterator(Ψ)) == [ - (A = 1, B = "AC") (A = 1, B = "AA") - (A = 0, B = "AC") (A = 0, B = "AA")] - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=1.) == false - freqs = Dict( - (A = 1, B = "CC") => 0.125, - (A = 1, B = "AA") => 0.25, - (A = 0, B = "AA") => 0.25, - (A = 0, B = "AC") => 0.25, - (A = 1, B = "AC") => 0.125, - (A = 2, B = "AA") => 0.25 - ) - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.3) == false - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.1) == true - - Ψ = IATE( - outcome = :toto, - treatment_values = (B=(case="AC", control="AA"), A=(case=1, control=0), C=(control=0, case=2)), - treatment_confounders = (B=(), A=(), C=()) - ) - expected_settings = Set([ - (A = 1, B = "AC", C = 0), - (A = 0, B = "AC", C = 0), - (A = 1, B = "AA", C = 0), - (A = 0, B = "AA", C = 0), - (A = 1, B = "AC", C = 2), - (A = 0, B = "AC", C = 2), - (A = 1, B = "AA", C = 2), - (A = 0, B = "AA", C = 2)]) - @test expected_settings == Set(TargeneCore.setting_iterator(Ψ)) -end - end true \ No newline at end of file