Skip to content

Commit

Permalink
Add a mode that skips geneATLAS variant sampling, for end-to-end testing
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Jul 18, 2024
1 parent 42d3cd4 commit d2dd7c0
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/Simulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ include(joinpath("density_estimation", "density_estimation.jl"))
include(joinpath("samplers", "null_sampler.jl"))
include(joinpath("samplers", "density_estimate_sampler.jl"))

include(joinpath("inputs_from_gene_atlas.jl"))
include(joinpath("realistic_simulation_inputs.jl"))
include("estimation.jl")
include("cli.jl")

Expand Down
27 changes: 20 additions & 7 deletions src/cli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ function cli_settings()
action = :command
help = "Estimate a conditional density."

"simulation-inputs-from-ga"
"realistic-simulation-inputs"
action = :command
help = "Generate simulation inputs from geneATLAS."
help = "Generate realistic simulation inputs optionally using geneATLAS hits."

"analyse"
action = :command
Expand Down Expand Up @@ -157,7 +157,7 @@ function cli_settings()

end

@add_arg_table! s["simulation-inputs-from-ga"] begin
@add_arg_table! s["realistic-simulation-inputs"] begin
"estimands-prefix"
arg_type = String
help = "A prefix to serialized TMLE.Configuration (accepted formats: .json | .yaml | .jls)"
Expand All @@ -173,6 +173,11 @@ function cli_settings()
"pcs"
arg_type = String
help = "The dataset of principal components."

"--sample-gene-atlas-hits"
arg_type = Bool
default = true
help = "Whether to sample additional variants from the geneATLAS."

"--ga-download-dir"
arg_type = String
Expand Down Expand Up @@ -233,7 +238,11 @@ function cli_settings()
arg_type = Int
default = 10
help = "Estimands are further split in files of `batchsize`"


"--variants-regex"
arg_type = String
default = "^(rs[0-9]*|Affx)"
help = "Regular expression to identify genetic variants from estimands."
end

return s
Expand Down Expand Up @@ -262,12 +271,13 @@ function julia_main()::Cint
)
elseif cmd == "aggregate"
save_aggregated_df_results(cmd_settings["input-prefix"], cmd_settings["out"])
elseif cmd == "simulation-inputs-from-ga"
simulation_inputs_from_gene_atlas(
elseif cmd == "realistic-simulation-inputs"
realistic_simulation_inputs(
cmd_settings["estimands-prefix"],
cmd_settings["bgen-prefix"],
cmd_settings["traits"],
cmd_settings["pcs"];
sample_gene_atlas_hits=cmd_settings["sample-gene-atlas-hits"],
gene_atlas_dir=cmd_settings["ga-download-dir"],
remove_ga_data=cmd_settings["remove-ga-data"],
trait_table_path=cmd_settings["ga-trait-table"],
Expand All @@ -279,7 +289,8 @@ function julia_main()::Cint
verbosity=cmd_settings["verbosity"],
output_prefix=cmd_settings["output-prefix"],
batchsize=cmd_settings["batchsize"],
max_variants=cmd_settings["max-variants"]
max_variants=cmd_settings["max-variants"],
variants_regex=cmd_settings["variants-regex"]
)
elseif cmd == "density-estimation-inputs"
density_estimation_inputs(
Expand All @@ -306,6 +317,8 @@ function julia_main()::Cint
dataset_file=cmd_settings["dataset-file"],
density_estimates_prefix=cmd_settings["density-estimates-prefix"],
)
else
throw(ArgumentError(string("No function matching command:", cmd)))
end

return 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ function download_variants_info(outdir)
end
end

function get_trait_to_variants_from_estimands(estimands; regex=r"^(rs[0-9]*|Affx)")
function get_trait_to_variants_from_estimands(estimands; variants_regex=r"^(rs[0-9]*|Affx)")
trait_to_variants = Dict()
for Ψ in estimands
outcome = string(TargeneCore.get_outcome(Ψ))
variants = filter(x -> occursin(regex, x), string.(TargeneCore.get_treatments(Ψ)))
variants = filter(x -> occursin(variants_regex, x), string.(TargeneCore.get_treatments(Ψ)))
if haskey(trait_to_variants, outcome)
union!(trait_to_variants[outcome], variants)
else
Expand Down Expand Up @@ -174,7 +174,9 @@ function group_by_outcome(estimands)
return groups
end

function get_trait_to_variants(estimands;
function get_trait_to_variants(estimands;
variants_regex=r"^(rs[0-9]*|Affx)",
sample_gene_atlas_hits=true,
verbosity=0,
gene_atlas_dir="gene_atlas_data",
remove_ga_data=true,
Expand All @@ -187,19 +189,21 @@ function get_trait_to_variants(estimands;
)
verbosity > 0 && @info("Retrieve significant variants for each outcome.")
# Retrieve traits and variants from estimands
trait_to_variants = get_trait_to_variants_from_estimands(estimands)
# Retrieve Trait to geneAtlas key map
trait_key_map = get_trait_key_map(keys(trait_to_variants), trait_table_path=trait_table_path)
# Update variant set for each trait using geneAtlas summary statistics
update_trait_to_variants_from_gene_atlas!(trait_to_variants, trait_key_map;
gene_atlas_dir=gene_atlas_dir,
remove_ga_data=remove_ga_data,
maf_threshold=maf_threshold,
pvalue_threshold=pvalue_threshold,
distance_threshold=distance_threshold,
max_variants=max_variants,
bgen_prefix=bgen_prefix
)
trait_to_variants = get_trait_to_variants_from_estimands(estimands;variants_regex=variants_regex)
if sample_gene_atlas_hits
# Retrieve Trait to geneAtlas key map
trait_key_map = get_trait_key_map(keys(trait_to_variants), trait_table_path=trait_table_path)
# Update variant set for each trait using geneAtlas summary statistics
update_trait_to_variants_from_gene_atlas!(trait_to_variants, trait_key_map;
gene_atlas_dir=gene_atlas_dir,
remove_ga_data=remove_ga_data,
maf_threshold=maf_threshold,
pvalue_threshold=pvalue_threshold,
distance_threshold=distance_threshold,
max_variants=max_variants,
bgen_prefix=bgen_prefix
)
end
return trait_to_variants
end

Expand All @@ -214,7 +218,7 @@ function get_dataset_and_validated_estimands(
verbosity=0
)
verbosity > 0 && @info("Calling genotypes.")
variants_set = Set(string.(vcat(values(trait_to_variants)...)))
variants_set = Set(string.(union(values(trait_to_variants)...)))

genotypes = TargeneCore.call_genotypes(
bgen_prefix,
Expand Down Expand Up @@ -315,11 +319,12 @@ function read_and_validate_estimands(estimands_prefix)
end

"""
simulation_inputs_from_gene_atlas(
realistic_simulation_inputs(
estimands_prefix,
bgen_prefix,
traits_file,
pcs_file;
sample_gene_atlas_hits=true,
gene_atlas_dir="gene_atlas_data",
remove_ga_data=true,
trait_table_path=joinpath("assets", "Traits_Table_GeneATLAS.csv"),
Expand All @@ -334,8 +339,8 @@ end
verbosity=0,
)
This function generates input files for realistic simulations using
variants identified from the geneATLAS.
This function generates input files for realistic simulations optionally sampling
variants identified from the geneATLAS (`sample_gene_atlas_hits`).
## What files are Generated ?
Expand Down Expand Up @@ -374,11 +379,12 @@ if `remove_ga_data`. For each outcome, variants are selected if:
Finally, a maximum of `max_variants` is retained per outcome.
"""
function simulation_inputs_from_gene_atlas(
function realistic_simulation_inputs(
estimands_prefix,
bgen_prefix,
traits_file,
pcs_file;
sample_gene_atlas_hits=true,
gene_atlas_dir="gene_atlas_data",
remove_ga_data=true,
trait_table_path=joinpath("assets", "Traits_Table_GeneATLAS.csv"),
Expand All @@ -390,12 +396,15 @@ function simulation_inputs_from_gene_atlas(
batchsize=10,
positivity_constraint=0,
call_threshold=0.9,
variants_regex="^(rs[0-9]*|Affx)",
verbosity=0,
)
Random.seed!(123)
estimands = read_and_validate_estimands(estimands_prefix)
# Trait to variants from geneATLAS
trait_to_variants = get_trait_to_variants(estimands;
trait_to_variants = get_trait_to_variants(estimands;
sample_gene_atlas_hits=sample_gene_atlas_hits,
variants_regex=Regex(variants_regex),
verbosity=verbosity,
gene_atlas_dir=gene_atlas_dir,
remove_ga_data=remove_ga_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,21 @@ end
estimands = linear_interaction_dataset_ATEs().estimands
push!(estimands, ATE(outcome=:Ycont, treatment_values=(T₃=(case=1, control=0),), treatment_confounders=(:W,)))
# Empty regex
trait_to_variants = Simulations.get_trait_to_variants_from_estimands(estimands; regex=r"")
trait_to_variants = Simulations.get_trait_to_variants_from_estimands(estimands; variants_regex=r"")
@test trait_to_variants == Dict(
"Ycont" => Set(["T₁", "T₃"]),
"Ybin" => Set(["T₁", "T₂"]),
"Ycount" => Set(["T₁"])
)
# T₂ regex
trait_to_variants = Simulations.get_trait_to_variants_from_estimands(estimands; regex=r"T₂")
trait_to_variants = Simulations.get_trait_to_variants_from_estimands(estimands; variants_regex=r"T₂")
@test trait_to_variants == Dict(
"Ycont" => Set(),
"Ybin" => Set(["T₂"]),
"Ycount" => Set()
)
# T₁ regex
trait_to_variants = Simulations.get_trait_to_variants_from_estimands(estimands; regex=r"T₁")
trait_to_variants = Simulations.get_trait_to_variants_from_estimands(estimands; variants_regex=r"T₁")
@test trait_to_variants == Dict(
"Ycont" => Set(["T₁"]),
"Ybin" => Set(["T₁"]),
Expand All @@ -147,10 +147,60 @@ end
)
end

@testset "Test simulation_inputs_from_gene_atlas" begin
# The function `simulation_inputs_from_gene_atlas` is hard to test end to end due to
# data limitations. It is split into 3 subfunctions that we here test sequentially but
# with different data.
@testset "Test realistic_simulation_inputs: sample_gene_atlas_hits=false" begin
    # End-to-end run of the `realistic-simulation-inputs` command through the CLI entry
    # point (`Simulations.julia_main`), with `--sample-gene-atlas-hits=false` so that the
    # geneATLAS sampling step is switched off.
    verbosity = 0

    # Input and output locations: estimands and outputs live in a fresh temporary
    # directory, genotypes/traits/PCs come from the TargeneCore test data.
    workdir = mktempdir()
    data_dir = joinpath(TARGENCORE_TESTDIR, "data")
    estimands_prefix = joinpath(workdir, "estimands.jls")
    output_prefix = joinpath(workdir, "realistic_inputs")
    bgen_prefix = joinpath(data_dir, "ukbb", "imputed", "ukbb")
    traits_file = joinpath(data_dir, "traits_1.csv")
    pcs_file = joinpath(data_dir, "pcs.csv")

    # Serialize the estimands so the command can read them back from `estimands_prefix`.
    estimands, _ = estimands_and_traits_to_variants_matching_bgen()
    configuration = TMLE.Configuration(estimands=estimands)
    serialize(estimands_prefix, configuration)

    # Command line: 4 positional arguments followed by the options.
    ga_trait_table = joinpath(PKGDIR, "assets", "Traits_Table_GeneATLAS.csv")
    cli_args = [
        "realistic-simulation-inputs",
        estimands_prefix,
        bgen_prefix,
        traits_file,
        pcs_file,
        "--sample-gene-atlas-hits=false",
        "--ga-download-dir=gene_atlas_data",
        "--remove-ga-data=true",
        string("--ga-trait-table=", ga_trait_table),
        "--maf-threshold=0.01",
        "--pvalue-threshold=1e-5",
        "--distance-threshold=1e6",
        "--max-variants=100",
        string("--output-prefix=", output_prefix),
        "--batchsize=10",
        "--positivity-constraint=0",
        "--call-threshold=0.9",
        "--verbosity=0",
        "--variants-regex=^RS",
    ]
    copy!(ARGS, cli_args)
    Simulations.julia_main()

    # Check 1: the conditional density specifications written as JSON files.
    density_files = TargeneCore.files_matching_prefix(
        string(output_prefix, ".conditional_density")
    )
    conditional_densities = Set(JSON.parsefile(path) for path in density_files)
    expected_densities = Set([
        Dict(
            "parents" => Any["RSID_2", "COV_1", "PC2", "21003", "PC1"],
            "outcome" => "BINARY_2",
        ),
        Dict("parents" => Any["PC2", "PC1"], "outcome" => "TREAT_1"),
        Dict("parents" => Any["PC2", "PC1"], "outcome" => "RSID_2"),
        Dict(
            "parents" => Any["RSID_2", "22001", "PC2", "RSID_198", "PC1", "21003", "COV_1"],
            "outcome" => "CONTINUOUS_2",
        ),
        Dict(
            "parents" => Any["TREAT_1", "RSID_2", "22001", "PC2", "RSID_198", "PC1"],
            "outcome" => "BINARY_1",
        ),
        Dict("parents" => Any["PC2", "PC1"], "outcome" => "RSID_198"),
    ])
    @test conditional_densities == expected_densities

    # Check 2: the column names (and their order) of the generated Arrow dataset.
    dataset = DataFrame(Arrow.Table(string(output_prefix, ".data.arrow")))
    expected_columns = [
        "SAMPLE_ID", "BINARY_1", "BINARY_2", "CONTINUOUS_1", "CONTINUOUS_2", "COV_1",
        "21003", "22001", "TREAT_1", "PC1", "PC2", "RSID_2", "RSID_198",
    ]
    @test names(dataset) == expected_columns

    # Check 3: the first estimands batch holds as many estimands as were provided.
    output_config = deserialize(string(output_prefix, ".estimands_1.jls"))
    @test length(estimands) == length(output_config.estimands)
end

@testset "Test realistic_simulation_inputs: sample_gene_atlas_hits=true" begin
# The function `realistic_simulation_inputs` is hard to test end to end when involving sampling
# variants from the gene-atlas. It is split into 3 subfunctions that we here test sequentially but
# each with different data.
verbosity = 0
# Here we use the real geneATLAS data
tmpdir = mktempdir()
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ TESTDIR = joinpath(pkgdir(Simulations), "test")
@testset "Simulations.jl" begin
# Unit Tests
@test include(joinpath(TESTDIR, "utils.jl"))
@test include(joinpath(TESTDIR, "inputs_from_gene_atlas.jl"))
@test include(joinpath(TESTDIR, "realistic_simulation_inputs.jl"))

@test include(joinpath(TESTDIR, "density_estimation", "glm.jl"))
@test include(joinpath(TESTDIR, "density_estimation", "neural_net.jl"))
Expand Down

0 comments on commit d2dd7c0

Please sign in to comment.