NamedTuple to Dictionary changes for overhead alleviation
joshua-slaughter committed Jul 29, 2024
1 parent 13d2bd2 commit 9ed04e3
Showing 6 changed files with 187 additions and 54 deletions.
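Why the change helps (an illustrative sketch, not part of this diff; the variant names below are hypothetical): a NamedTuple's field names are part of its type, so building one NamedTuple per variant gives every variant its own concrete type, and downstream functions are specialized and recompiled for each of them. A Dict{Symbol, Vector{UInt8}} is a single concrete type that is reused for every variant.

nt_a = (rs1 = UInt8[0x00, 0x01, 0x02],)    # typeof: NamedTuple{(:rs1,), Tuple{Vector{UInt8}}}
nt_b = (rs2 = UInt8[0x00, 0x01, 0x02],)    # typeof: NamedTuple{(:rs2,), Tuple{Vector{UInt8}}}
typeof(nt_a) == typeof(nt_b)               # false: one type per variant name

dict_a = Dict{Symbol, Vector{UInt8}}(:rs1 => UInt8[0x00, 0x01, 0x02])
dict_b = Dict{Symbol, Vector{UInt8}}(:rs2 => UInt8[0x00, 0x01, 0x02])
typeof(dict_a) == typeof(dict_b)           # true: the same type for every variant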
2 changes: 2 additions & 0 deletions .gitignore
@@ -34,3 +34,5 @@ sandbox.jl
# Jupyter
*checkpoints
src/generate_results.jl
test_grid.csv
test/tl_inputs/real_data.jl
18 changes: 9 additions & 9 deletions Manifest.toml
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.3"
julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "154e3366f7fc60caf52dba209a93aced22d33ef0"

@@ -1916,9 +1916,9 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[deps.ProgressMeter]]
deps = ["Distributed", "Printf"]
git-tree-sha1 = "763a8ceb07833dd51bb9e3bbca372de32c0605ad"
git-tree-sha1 = "80686d28ecb3ee7fb3ac5371cacaa0d673eb0d4a"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.10.0"
version = "1.10.1"

[[deps.PtrArrays]]
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
@@ -2344,9 +2344,9 @@ uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "7.2.1+1"

[[deps.TMLE]]
deps = ["AbstractDifferentiation", "CategoricalArrays", "Combinatorics", "Distributions", "GLM", "Graphs", "HypothesisTests", "LogExpFunctions", "MLJBase", "MLJGLMInterface", "MLJModels", "MetaGraphsNext", "Missings", "PrecompileTools", "Random", "SplitApplyCombine", "Statistics", "TableOperations", "Tables", "Zygote"]
git-tree-sha1 = "86f8ca5c47ab2b96a871beaeaae649a865d7d3f6"
repo-rev = "agnostic_composed"
deps = ["AbstractDifferentiation", "CategoricalArrays", "Combinatorics", "DataFrames", "Dictionaries", "Distributions", "GLM", "Graphs", "HypothesisTests", "LogExpFunctions", "MLJBase", "MLJGLMInterface", "MLJModels", "MetaGraphsNext", "Missings", "PrecompileTools", "Random", "SplitApplyCombine", "Statistics", "TableOperations", "Tables", "Zygote"]
git-tree-sha1 = "e547cc7fd91ca21cadce6b2ef71c22a7407ff42d"
repo-rev = "frequency_table"
repo-url = "https://github.com/TARGENE/TMLE.jl"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
version = "0.16.1"
@@ -2380,10 +2380,10 @@ uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
version = "1.0.1"

[[deps.Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"]
git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d"
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"]
git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.11.1"
version = "1.12.0"

[[deps.Tar]]
deps = ["ArgTools", "SHA"]
122 changes: 82 additions & 40 deletions src/tl_inputs/allele_independent_estimands.jl
@@ -29,6 +29,16 @@ function treatment_tuples_from_groups(treatments_lists, orders)
return sort(treatment_combinations)
end

"""
treatment_from_variant(variant, dataset)
Generate a key-value pair (dicitionary) for treatment structs.
"""
function treatment_from_variant(variant::String, dataset::DataFrame)
variant_levels = sort(levels(dataset[!, variant], skipmissing=true))
return Dict{Symbol, Vector{UInt8}}(Symbol(variant)=>variant_levels)
end
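# Illustrative usage sketch, not part of this diff (hypothetical variant name and toy data):
#
#     toy = DataFrame(rs1 = categorical(UInt8[0x00, 0x01, 0x02, 0x01]))
#     treatment_from_variant("rs1", toy)    # -> Dict(:rs1 => UInt8[0x00, 0x01, 0x02])
#
# Every variant now maps to a value of the same concrete type, Dict{Symbol, Vector{UInt8}},
# rather than to a differently-typed NamedTuple per variant.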

function try_append_new_estimands!(
estimands,
dataset,
@@ -59,36 +69,40 @@ function try_append_new_estimands!(
end
end

function try_index_new_estimands!(
estimands,
index,

function try_estimands(
treatments,
dataset,
estimand_constructor,
treatments,
outcomes,
confounders;
outcome_extra_covariates=[],
positivity_constraint=0.,
verbosity=1
)
local Ψ
try
Ψ = factorialEstimands(
estimand_constructor, treatments, outcomes;
confounders=confounders,
dataset=dataset,
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity-1
)
catch e
if !(e == ArgumentError("No component passed the positivity constraint."))
throw(e)
estimands = []
for t in treatments
local Ψ
try
Ψ = factorialEstimands(
estimand_constructor, t, outcomes;
confounders=confounders,
dataset=dataset,
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity-1)

catch e
if !(e == ArgumentError("No component passed the positivity constraint."))
throw(e)
end
else
append!(estimands, Ψ)
end
else
estimands[index] = Ψ
end
return estimands
end
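# Illustrative sketch, not part of this commit: the speedup note preceding gwas_estimands_chunks
# below refers to Julia's @nospecialize hint, which asks the compiler not to specialize a method
# on an argument's concrete type, trading a little runtime dispatch for far less compilation.
# Minimal example with a hypothetical function name:
n_components(@nospecialize(treatments)) = length(treatments)    # compiled once, reused for any collection type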

function estimands_from_groups(estimands_configs, dataset, variants_config, outcomes, confounders;
extra_treatments=[],
outcome_extra_covariates=[],
@@ -156,31 +170,43 @@ function estimands_from_flat_list(estimands_configs, dataset, variants, outcomes
return estimands
end

function gwas_estimands(dataset, variants, outcomes, confounders;
# For a significant speedup (NOT a fix), add @nospecialize to try_estimands()
function gwas_estimands_chunks(dataset, treatments, outcomes, confounders;
outcome_extra_covariates = [],
positivity_constraint=0.,
verbosity=0
)
chunks = Iterators.partition(treatments, Int(ceil(length(treatments)/Threads.nthreads())))
tasks = map(chunks) do chunk
Threads.@spawn try_estimands(chunk, dataset, ATE, outcomes, confounders;
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity)
end
chunk_estimands = fetch.(tasks)
return vcat(chunk_estimands...)
end
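# Illustrative note, not part of this commit: a chunk size of ceil(length/nthreads) produces at
# most Threads.nthreads() chunks, so at most one spawned task per thread. For example, splitting
# 10 treatments across 4 threads:
#
#     collect(Iterators.partition(1:10, cld(10, 4)))    # -> [1:3, 4:6, 7:9, 10:10]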

function gwas_estimands_serial(dataset, treatments, outcomes, confounders;
outcome_extra_covariates=[],
positivity_constraint=0.,
verbosity=0,
)
estimands = Vector{Union{Any, Missing}}(undef, length(variants))
fill!(estimands, missing)
Threads.@threads for (index, v) ∈ collect(enumerate(variants))
variant_levels = sort(levels(dataset[!, v], skipmissing=true))
treatments = NamedTuple{(Symbol(v),)}([variant_levels])
try_index_new_estimands!(
estimands,
index,
estimands = []
for t ∈ treatments
try_append_new_estimands!(
estimands,
dataset,
ATE,
treatments,
t,
outcomes,
confounders;
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity
)
end
estimands = vcat(estimands...)
filter!(x -> x !== missing, estimands)
return estimands
end

get_only_file_with_suffix(files, suffix) = files[only(findall(x -> endswith(x, suffix), files))]
@@ -206,7 +232,7 @@ function get_genotypes_from_beds(bedprefix)
end

function make_genotypes(genotype_prefix, config, call_threshold)
genotypes = if config["type"] == "gwas"
genotypes = if config["type"] == "gwas_parallel" || config["type"] == "gwas_serial" || config["type"] == "gwas"
get_genotypes_from_beds(genotype_prefix)
else
variants_set = Set(retrieve_variants_list(config["variants"]))
@@ -251,29 +277,45 @@ function allele_independent_estimands(parsed_args)
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity)

elseif config_type == "groups"
estimands_from_groups(config["estimands"], dataset, config["variants"], outcomes, confounders;
extra_treatments=extra_treatments,
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity
)
elseif config_type == "gwas"
verbosity=verbosity)

elseif config_type == "gwas_parallel"
verbosity > 0 && @info("Generating estimands.")
variants = filter(!=("SAMPLE_ID"), names(genotypes))
treatments = Vector{Dict{Symbol, Vector{UInt8}}}()
for v in variants
push!(treatments, treatment_from_variant(v, dataset))
end
gwas_estimands_chunks(dataset, treatments, outcomes, confounders;
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity)

elseif config_type == "gwas_serial"
verbosity > 0 && @info("Generating estimands.")
variants = filter(!=("SAMPLE_ID"), names(genotypes))
gwas_estimands(dataset, variants, outcomes, confounders;
treatments = Vector{Dict{Symbol, Vector{UInt8}}}()
for v in variants
push!(treatments, treatment_from_variant(v, dataset))
end
gwas_estimands_serial(dataset, treatments, outcomes, confounders;
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
verbosity=verbosity
)
verbosity=verbosity)
else
throw(ArgumentError(string("Unknown extraction type: ", config_type, ", use any of: (flat, groups, gwas_parallel, gwas_serial)")))
end

@assert length(estimands) > 0 "No estimands left, probably due to a too high positivity constraint."

save_estimands(outprefix, groups_ordering(estimands), batchsize)

verbosity > 0 && @info("Saving estimands.")
save_estimands(outprefix, estimands, batchsize)
verbosity > 0 && @info("Done.")

return 0
8 changes: 8 additions & 0 deletions test/data/config_gwas_parallel.yaml
@@ -0,0 +1,8 @@
type: gwas_parallel

outcome_extra_covariates:
- COV_1

extra_confounders:
- 21003
- 22001
@@ -1,4 +1,4 @@
type: gwas
type: gwas_serial

outcome_extra_covariates:
- COV_1
89 changes: 85 additions & 4 deletions test/tl_inputs/loco_gwas.jl
@@ -33,7 +33,7 @@ function check_estimands_levels_order(estimands)
end
end
end
@testset "Test loco-gwas from flat list: no positivity constraint" begin
@testset "Test loco-gwas serial: no positivity constraint" begin
tmpdir = mktempdir()
parsed_args = Dict(
"verbosity" => 0,
@@ -45,7 +45,7 @@

"allele-independent" => Dict{String, Any}(
"call-threshold" => nothing,
"config" => joinpath(TESTDIR, "data", "config_gwas.yaml"),
"config" => joinpath(TESTDIR, "data", "config_gwas_serial.yaml"),
"traits" => joinpath(TESTDIR, "data", "ukbb_traits.csv"),
"pcs" => joinpath(TESTDIR, "data", "ukbb_pcs.csv"),
"genotype-prefix" => joinpath(TESTDIR, "data", "ukbb", "genotypes" ,"ukbb_1."),
@@ -74,7 +74,7 @@ end
check_estimands_levels_order(estimands)
end

@testset "Test loco-gwas from flat list: positivity constraint" begin
@testset "Test loco-gwas serial: positivity constraint" begin
tmpdir = mktempdir()
parsed_args = Dict(
"verbosity" => 0,
Expand All @@ -86,7 +86,88 @@ end

"allele-independent" => Dict{String, Any}(
"call-threshold" => nothing,
"config" => joinpath(TESTDIR, "data", "config_gwas.yaml"),
"config" => joinpath(TESTDIR, "data", "config_gwas_serial.yaml"),
"traits" => joinpath(TESTDIR, "data", "ukbb_traits.csv"),
"pcs" => joinpath(TESTDIR, "data", "ukbb_pcs.csv"),
"genotype-prefix" => joinpath(TESTDIR, "data", "ukbb", "genotypes" ,"ukbb_1")
),
)
tl_inputs(parsed_args)
# Check dataset
trait_data = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
@test size(trait_data) == (1940, 886)
# Check estimands
estimands = []
for file in readdir(tmpdir, join=true)
if endswith(file, "jls")
append!(estimands, deserialize(file).estimands)
end
end
@test all(e isa JointEstimand for e in estimands)
summary_stats = get_summary_stats(estimands)
@test summary_stats == DataFrame(
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
nrow = repeat([777], 5)
)

check_estimands_levels_order(estimands)

end

@testset "Test loco-gwas parallel: no positivity constraint" begin
tmpdir = mktempdir()
parsed_args = Dict(
"verbosity" => 0,
"out-prefix" => joinpath(tmpdir, "final"),
"batch-size" => 5,
"positivity-constraint" => 0.0,

"%COMMAND%" => "allele-independent",

"allele-independent" => Dict{String, Any}(
"call-threshold" => nothing,
"config" => joinpath(TESTDIR, "data", "config_gwas_parallel.yaml"),
"traits" => joinpath(TESTDIR, "data", "ukbb_traits.csv"),
"pcs" => joinpath(TESTDIR, "data", "ukbb_pcs.csv"),
"genotype-prefix" => joinpath(TESTDIR, "data", "ukbb", "genotypes" ,"ukbb_1."),
),
)
tl_inputs(parsed_args)
# Check dataset
trait_data = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
@test size(trait_data) == (1940, 886)

# Check estimands
estimands = []
for file in readdir(tmpdir, join=true)
if endswith(file, "jls")
append!(estimands, deserialize(file).estimands)
end
end
@test all(e isa JointEstimand for e in estimands)

summary_stats = get_summary_stats(estimands)
@test summary_stats == DataFrame(
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
nrow = repeat([875], 5)
)

check_estimands_levels_order(estimands)
end

@testset "Test loco-gwas parallel: positivity constraint" begin
tmpdir = mktempdir()
parsed_args = Dict(
"verbosity" => 0,
"out-prefix" => joinpath(tmpdir, "final"),
"batch-size" => 5,
"positivity-constraint" => 0.2,

"%COMMAND%" => "allele-independent",

"allele-independent" => Dict{String, Any}(
"call-threshold" => nothing,
"config" => joinpath(TESTDIR, "data", "config_gwas_parallel.yaml"),
"traits" => joinpath(TESTDIR, "data", "ukbb_traits.csv"),
"pcs" => joinpath(TESTDIR, "data", "ukbb_pcs.csv"),
"genotype-prefix" => joinpath(TESTDIR, "data", "ukbb", "genotypes" ,"ukbb_1")
