Skip to content

Commit

Permalink
Revert "revert to stable commit"
Browse files Browse the repository at this point in the history
This reverts commit 5c30fc3.
  • Loading branch information
joshua-slaughter committed Jun 6, 2024
1 parent f971341 commit a53dd4e
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 162 deletions.
6 changes: 3 additions & 3 deletions src/tl_inputs/from_actors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function control_case_settings(::Type{TMLE.StatisticalATE}, treatments, data)
end

function addEstimands!(estimands, treatments, variables, data; positivity_constraint=0.)
freqs = TMLE.frequency_table(data, treatments)
freqs = TargeneCore.frequency_table(data, treatments)
# This loop adds all ATE estimands where all other treatments than
# the bQTL are fixed, at the order 1, this is the simple bQTL's ATE
for setting in control_case_settings(TMLE.StatisticalATE, treatments, data)
Expand All @@ -134,7 +134,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
outcome_extra_covariates = variables.covariates
)
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
end
end
Expand All @@ -147,7 +147,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
outcome_extra_covariates = variables.covariates
)
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
end
end
Expand Down
81 changes: 15 additions & 66 deletions src/tl_inputs/from_param_files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ MismatchedCaseControlEncodingError() =

NoRemainingParamsError(positivity_constraint) = ArgumentError(string("No parameter passed the given positivity constraint: ", positivity_constraint))

MismatchedVariableError(variable) = ArgumentError(string("Each component of a ComposedEstimand should contain the same ", variable, " variables."))

function check_genotypes_encoding(val::NamedTuple, type)
if !(typeof(val.case) <: type && typeof(val.control) <: type)
Expand All @@ -28,66 +27,17 @@ check_genotypes_encoding(val::T, type) where T =
T <: type || throw(MismatchedCaseControlEncodingError())


get_treatments(Ψ) = keys.treatment_values)

function get_treatments::ComposedEstimand)
treatments = get_treatments(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_treatments(arg) == treatments || throw(MismatchedVariableError("treatments"))
end
end
return treatments
end

get_confounders(Ψ) = Tuple(Iterators.flatten((Tconf for Tconf Ψ.treatment_confounders)))

function get_confounders::ComposedEstimand)
confounders = get_confounders(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_confounders(arg) == confounders || throw(MismatchedVariableError("confounders"))
end
end
return confounders
end

get_outcome_extra_covariates(Ψ) = Ψ.outcome_extra_covariates

function get_outcome_extra_covariates::ComposedEstimand)
outcome_extra_covariates = get_outcome_extra_covariates(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_outcome_extra_covariates(arg) == outcome_extra_covariates || throw(MismatchedVariableError("outcome extra covariates"))
end
end
return outcome_extra_covariates
end

get_outcome(Ψ) = Ψ.outcome

function get_outcome::ComposedEstimand)
outcome = get_outcome(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_outcome(arg) == outcome || throw(MismatchedVariableError("outcome"))
end
end
return outcome
end

function get_variables(estimands, traits, pcs)
genetic_variants = Set{Symbol}()
others = Set{Symbol}()
pcs = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(pcs)))
alltraits = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(traits)))
for Ψ in estimands
treatments = get_treatments(Ψ)
confounders = get_confounders(Ψ)
outcome_extra_covariates = get_outcome_extra_covariates(Ψ)
treatments = keys.treatment_values)
confounders = Iterators.flatten((Tconf for Tconf Ψ.treatment_confounders))
push!(
others,
outcome_extra_covariates...,
Ψ.outcome_extra_covariates...,
confounders...,
treatments...
)
Expand Down Expand Up @@ -173,8 +123,6 @@ function adjust_parameter_sections(Ψ::T, variants_alleles, pcs) where T<:TMLE.E
return T(outcome=Ψ.outcome, treatment_values=treatments, treatment_confounders=confounders, outcome_extra_covariates=Ψ.outcome_extra_covariates)
end

adjust_parameter_sections::ComposedEstimand, variants_alleles, pcs) =
ComposedEstimand.f, Tuple(adjust_parameter_sections(arg, variants_alleles, pcs) for arg in Ψ.args))

function append_from_valid_estimands!(
estimands::Vector{<:TMLE.Estimand},
Expand All @@ -188,28 +136,29 @@ function append_from_valid_estimands!(
# Update treatment's and confounders's sections of Ψ
Ψ = adjust_parameter_sections(Ψ, variants_alleles, variables.pcs)
# Update frequency tables with current treatments
treatments = get_treatments(Ψ)
treatments = sorted_treatment_names(Ψ)
if !haskey(frequency_tables, treatments)
frequency_tables[treatments] = TMLE.frequency_table(data, treatments)
frequency_tables[treatments] = TargeneCore.frequency_table(data, collect(treatments))
end
# Check if parameter satisfies positivity
if TMLE.satisfies_positivity(Ψ, frequency_tables[treatments]; positivity_constraint=positivity_constraint)
# Expand wildcard to all outcomes
if get_outcome(Ψ) === :ALL
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
else
push!(estimands, Ψ)
end
satisfies_positivity(Ψ, frequency_tables[treatments];
positivity_constraint=positivity_constraint) || return
# Expand wildcard to all outcomes
if Ψ.outcome === :ALL
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
else
# Ψ.target || MissingVariableError(variable)
push!(estimands, Ψ)
end
end

function adjusted_estimands(estimands, variables, data; positivity_constraint=0.)
final_estimands = TMLE.Estimand[]
variants_alleles = Dict(v => Set(unique(skipmissing(data[!, v]))) for v in variables.genetic_variants)
frequency_tables = Dict()
freqency_tables = Dict()
for Ψ in estimands
# If the genotypes encoding is a string representation make sure they match the actual genotypes
append_from_valid_estimands!(final_estimands, frequency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
append_from_valid_estimands!(final_estimands, freqency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
end

length(final_estimands) > 0 || throw(NoRemainingParamsError(positivity_constraint))
Expand Down
68 changes: 47 additions & 21 deletions src/tl_inputs/tl_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ NotAllVariantsFoundError(rsids) =
ArgumentError(string("Some variants were not found in the genotype files: ", join(rsids, ", ")))

NotBiAllelicOrUnphasedVariantError(rsid) = ArgumentError(string("Variant: ", rsid, " is not bi-allelic or not unphased."))

"""
bgen_files(snps, bgen_prefix)
Expand Down Expand Up @@ -104,8 +103,47 @@ function call_genotypes(bgen_prefix::String, query_rsids::Set{<:AbstractString},
return genotypes
end

TMLE.satisfies_positivity::ComposedEstimand, freqs; positivity_constraint=0.01) =
all(TMLE.satisfies_positivity(arg, freqs; positivity_constraint=positivity_constraint) for arg in Ψ.args)
sorted_treatment_names(Ψ) = tuple(sort(collect(keys.treatment_values)))...)

function setting_iterator::TMLE.StatisticalIATE)
treatments = sorted_treatment_names(Ψ)
return (
NamedTuple{treatments}(collect(Tval)) for
Tval in Iterators.product((values.treatment_values[T]) for T in treatments)...)
)
end

function setting_iterator::TMLE.StatisticalATE)
treatments = sorted_treatment_names(Ψ)
return (
NamedTuple{treatments}([(Ψ.treatment_values[T][c]) for T in treatments])
for c in (:case, :control)
)
end

function setting_iterator::TMLE.StatisticalCM)
treatments = sorted_treatment_names(Ψ)
return (NamedTuple{treatments}.treatment_values[T] for T in treatments), )
end

function satisfies_positivity::TMLE.Estimand, freqs; positivity_constraint=0.01)
for base_setting in setting_iterator(Ψ)
if !haskey(freqs, base_setting) || freqs[base_setting] < positivity_constraint
return false
end
end
return true
end

function frequency_table(data, treatments::AbstractVector)
treatments = sort(treatments)
freqs = Dict()
N = nrow(data)
for (key, group) in pairs(groupby(data, treatments; skipmissing=true))
freqs[NamedTuple(key)] = nrow(group) / N
end
return freqs
end

read_txt_file(path::Nothing) = nothing
read_txt_file(path) = CSV.read(path, DataFrame, header=false)[!, 1]
Expand All @@ -126,27 +164,15 @@ function merge(traits, pcs, genotypes)
)
end

estimand_with_new_outcome::T, outcome) where T = T(
outcome=outcome,
treatment_values=Ψ.treatment_values,
treatment_confounders=Ψ.treatment_confounders,
outcome_extra_covariates=Ψ.outcome_extra_covariates
)

function update_estimands_from_outcomes!(estimands, Ψ::T, outcomes) where T
for outcome in outcomes
push!(
estimands,
estimand_with_new_outcome(Ψ, outcome)
)
end
end

function update_estimands_from_outcomes!(estimands, Ψ::ComposedEstimand, outcomes)
for outcome in outcomes
push!(
estimands,
ComposedEstimand.f, Tuple(estimand_with_new_outcome(arg, outcome) for arg in Ψ.args))
estimands,
T(
outcome=outcome,
treatment_values=Ψ.treatment_values,
treatment_confounders=Ψ.treatment_confounders,
outcome_extra_covariates=Ψ.outcome_extra_covariates)
)
end
end
Expand Down
67 changes: 9 additions & 58 deletions test/tl_inputs/from_param_files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,40 +24,6 @@ include(joinpath(TESTDIR, "tl_inputs", "test_utils.jl"))
pcs = TargeneCore.read_csv_file(joinpath(TESTDIR, "data", "pcs.csv"))
# extraW, extraT, extraC are parsed from all param_files
estimands = make_estimands_configuration().estimands
# get_treatments, get_outcome, ...
## Simple Estimand
Ψ = estimands[1]
@test TargeneCore.get_outcome(Ψ) == :ALL
@test TargeneCore.get_treatments(Ψ) == keys.treatment_values)
@test TargeneCore.get_confounders(Ψ) == ()
@test TargeneCore.get_outcome_extra_covariates(Ψ) == ()
## ComposedEstimand
Ψ = estimands[5]
@test TargeneCore.get_outcome(Ψ) == :ALL
@test TargeneCore.get_treatments(Ψ) == keys.args[1].treatment_values)
@test TargeneCore.get_confounders(Ψ) == ()
@test TargeneCore.get_outcome_extra_covariates(Ψ) == (Symbol("22001"), )
## Bad ComposedEstimand
Ψ = ComposedEstimand(
TMLE.joint_estimand, (
CM(
outcome = "Y1",
treatment_values = (RSID_3 = "GG", RSID_198 = "AG"),
treatment_confounders = (RSID_3 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
),
CM(
outcome = "Y2",
treatment_values = (RSID_2 = "AA", RSID_198 = "AG"),
treatment_confounders = (RSID_2 = [:PC1], RSID_198 = []),
outcome_extra_covariates = []
))
)
@test_throws ArgumentError TargeneCore.get_outcome(Ψ) == :ALL
@test_throws ArgumentError TargeneCore.get_treatments(Ψ)
@test_throws ArgumentError TargeneCore.get_confounders(Ψ)
@test_throws ArgumentError TargeneCore.get_outcome_extra_covariates(Ψ)
# get_variables
variables = TargeneCore.get_variables(estimands, traits, pcs)
@test variables.genetic_variants == Set([:RSID_198, :RSID_2])
@test variables.outcomes == Set([:BINARY_1, :CONTINUOUS_2, :CONTINUOUS_1, :BINARY_2])
Expand All @@ -72,9 +38,8 @@ end
)
pcs = Set([:PC1, :PC2])
variants_alleles = Dict(:RSID_198 => Set(genotypes.RSID_198))
estimands = make_estimands_configuration().estimands
# RS198 AG is not in the genotypes but GA is
Ψ = estimands[4]
# AG is not in the genotypes but GA is
Ψ = make_estimands_configuration().estimands[4]
@test Ψ.treatment_values.RSID_198 == (case="AG", control="AA")
new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
@test new_Ψ.outcome == Ψ.outcome
Expand All @@ -85,19 +50,6 @@ end
RSID_2 = (case = "AA", control = "GG")
)

# ComnposedEstimand
Ψ = estimands[5]
@test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
@test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA")
new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
for index in 1:length.args)
@test new_Ψ.args[index].outcome == Ψ.args[index].outcome
@test new_Ψ.args[index].outcome_extra_covariates == (Symbol(22001),)
@test new_Ψ.args[index].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2),)
end
@test new_Ψ.args[1].treatment_values == (RSID_198 = "GA", RSID_2 = "GG")
@test new_Ψ.args[2].treatment_values == (RSID_198 = "GA", RSID_2 = "AA")

# If the allele is not present
variants_alleles = Dict(:RSID_198 => Set(["AA"]))
@test_throws TargeneCore.AbsentAlleleError("RSID_198", "AG") TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
Expand Down Expand Up @@ -143,8 +95,8 @@ end

## Estimands file:
output_estimands = deserialize("final.estimands.jls").estimands
# There are 5 initial estimands containing a :ALL
# Those are duplicated for each of the 4 outcomes.
# There are 5 initial estimands containing a *
# Those are duplicated for each of the 4 targets.
@test length(output_estimands) == 20
# In all cases the PCs are appended to the confounders.
for Ψ output_estimands
Expand All @@ -168,11 +120,10 @@ end
@test Ψ.outcome_extra_covariates == (Symbol("22001"),)

# Input Estimand 5: GA is corrected to AG to match the data
elseif Ψ isa TMLE.ComposedEstimand
@test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
@test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA")
@test Ψ.args[1].treatment_confounders == Ψ.args[2].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2))
@test Ψ.args[1].outcome_extra_covariates == Ψ.args[2].outcome_extra_covariates == (Symbol("22001"),)
elseif Ψ isa TMLE.StatisticalCM && Ψ.treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
@test Ψ.treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2))
@test Ψ.outcome_extra_covariates == (Symbol("22001"),)

else
throw(AssertionError(string("Which input did this output come from: ", Ψ)))
end
Expand All @@ -191,7 +142,7 @@ end
tl_inputs(parsed_args)
# The IATES are the most sensitives
outestimands = deserialize("final.estimands.jls").estimands
@test allisa Union{TMLE.StatisticalCM, TMLE.StatisticalATE, ComposedEstimand} for Ψ in outestimands)
@test allisa Union{TMLE.StatisticalCM, TMLE.StatisticalATE} for Ψ in outestimands)
@test size(outestimands, 1) == 16

cleanup()
Expand Down
20 changes: 6 additions & 14 deletions test/tl_inputs/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ function cleanup(;prefix="final.")
end
end


function make_estimands_configuration()
estimands = [
IATE(
Expand All @@ -31,20 +32,11 @@ function make_estimands_configuration()
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
),
ComposedEstimand(
TMLE.joint_estimand, (
CM(
outcome = "ALL",
treatment_values = (RSID_2 = "GG", RSID_198 = "AG"),
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
),
CM(
outcome = "ALL",
treatment_values = (RSID_2 = "AA", RSID_198 = "AG"),
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
))
CM(
outcome = "ALL",
treatment_values = (RSID_2 = "GG", RSID_198 = "GA"),
treatment_confounders = (RSID_2 = [], RSID_198 = []),
outcome_extra_covariates = [22001]
)
]
return Configuration(estimands=estimands)
Expand Down
Loading

0 comments on commit a53dd4e

Please sign in to comment.