Skip to content

Commit

Permalink
faster genotype calling and estimands generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Olivier Labayle committed Dec 26, 2023
1 parent 25cdab7 commit 8fb9c43
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 35 deletions.
20 changes: 12 additions & 8 deletions src/tmle_inputs/allele_independent_estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ function allele_independent_estimands(parsed_args)
call_threshold = parsed_args["call-threshold"]
bgen_prefix = parsed_args["bgen-prefix"]
positivity_constraint = parsed_args["positivity-constraint"]
traits = TargeneCore.read_data(parsed_args["traits"])
pcs = TargeneCore.read_data(parsed_args["pcs"])
traits = read_data(parsed_args["traits"])
pcs = read_data(parsed_args["pcs"])
config = YAML.load_file(parsed_args["allele-independent"]["config"])

# Variables
Expand All @@ -39,20 +39,24 @@ function allele_independent_estimands(parsed_args)
outcomes = filter(x -> x ∉ nonoutcomes, Symbol.(names(traits)))

# Genotypes and final dataset
variants_set = Set(TargeneCore.retrieve_variants_list(variants))
genotypes = TargeneCore.call_genotypes(bgen_prefix, variants_set, call_threshold)
dataset = TargeneCore.merge(traits, pcs, genotypes)
variants_set = Set(retrieve_variants_list(variants))
genotypes = call_genotypes(bgen_prefix, variants_set, call_threshold)
dataset = merge(traits, pcs, genotypes)
Arrow.write(string(outprefix, ".data.arrow"), dataset)

# Estimands
for (groupname, variants_dict) ∈ variants
for prod ∈ Iterators.product(values(variants_dict)...)
treatments = vcat(collect(Symbol.(prod)), extra_treatments)
treatments_levels = TMLE.unique_treatment_values(dataset, treatments)
freq_table = positivity_constraint !== nothing ? TMLE.frequency_table(dataset, keys(treatments_levels)) : nothing
for outcome in outcomes
Ψ = generateIATEs(dataset, treatments, outcome,
confounders = confounders,
Ψ = generateIATEs(
treatments_levels, outcome;
confounders=confounders,
outcome_extra_covariates=outcome_extra_covariates,
positivity_constraint=positivity_constraint,
freq_table=freq_table,
positivity_constraint=positivity_constraint
)
ncomponents = length(Ψ.args)
if ncomponents > 0
Expand Down
30 changes: 14 additions & 16 deletions src/tmle_inputs/tmle_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ function read_bgen(filepath)
return Bgen(filepath, sample_path=sample_filepath, idx_path=idx_filepath)
end

all_snps_called(found_variants::Set{<:AbstractString}, variants::Set{<:AbstractString}) =
variants == found_variants

"""
genotypes_encoding(variant)
Expand All @@ -63,45 +60,46 @@ function genotypes_encoding(variant)
return [all₁*all₁, all₁*all₂, all₂*all₂]
end

NotAllVariantsFoundError(found_snps, snp_list) =
ArgumentError(string("Some variants were not found in the genotype files: ", join(setdiff(snp_list, found_snps), ", ")))
NotAllVariantsFoundError(rsids) =
ArgumentError(string("Some variants were not found in the genotype files: ", join(rsids, ", ")))

NotBiAllelicOrUnphasedVariantError(rsid) = ArgumentError(string("Variant: ", rsid, " is not bi-allelic or not unphased."))
"""
bgen_files(snps, bgen_prefix)
This function assumes the UK-Biobank structure
"""
function call_genotypes(bgen_prefix::String, variants::Set{<:AbstractString}, threshold::Real)
function call_genotypes(bgen_prefix::String, query_rsids::Set{<:AbstractString}, threshold::Real)
query_rsids = copy(query_rsids)
chr_dir_, prefix_ = splitdir(bgen_prefix)
chr_dir = chr_dir_ == "" ? "." : chr_dir_
genotypes = nothing
found_variants = Set{String}()
for filename in readdir(chr_dir)
all_snps_called(found_variants, variants) ? break : nothing
length(query_rsids) == 0 ? break : nothing
if is_numbered_chromosome_file(filename, prefix_)
bgenfile = read_bgen(joinpath(chr_dir_, filename))
chr_genotypes = DataFrame(SAMPLE_ID=bgenfile.samples)
for variant in BGEN.iterator(bgenfile)
rsid_ = rsid(variant)
if rsid_ ∈ variants
push!(found_variants, rsid_)
bgen_rsids = Set(rsids(bgenfile))
for query_rsid in query_rsids
if query_rsid ∈ bgen_rsids
pop!(query_rsids, query_rsid)
variant = variant_by_rsid(bgenfile, query_rsid)
if n_alleles(variant) != 2
@warn("Skipping $rsid_, not bi-allelic")
@warn("Skipping $query_rsid, not bi-allelic")

Check warning on line 88 in src/tmle_inputs/tmle_inputs.jl

View check run for this annotation

Codecov / codecov/patch

src/tmle_inputs/tmle_inputs.jl#L88

Added line #L88 was not covered by tests
continue
end
minor_allele_dosage!(bgenfile, variant)
variant_genotypes = genotypes_encoding(variant)
probabilities = probabilities!(bgenfile, variant)
size(probabilities, 1) != 3 && throw(NotBiAllelicOrUnphasedVariantError(rsid_))
chr_genotypes[!, rsid_] = call_genotypes(probabilities, variant_genotypes, threshold)
size(probabilities, 1) != 3 && throw(NotBiAllelicOrUnphasedVariantError(query_rsid))
chr_genotypes[!, query_rsid] = call_genotypes(probabilities, variant_genotypes, threshold)
end
end
genotypes = genotypes isa Nothing ? chr_genotypes :
innerjoin(genotypes, chr_genotypes, on=:SAMPLE_ID)
end
end
all_snps_called(found_variants, variants) || throw(NotAllVariantsFoundError(found_variants, variants))
length(query_rsids) == 0 || throw(NotAllVariantsFoundError(query_rsids))
return genotypes
end

Expand Down
8 changes: 4 additions & 4 deletions test/tmle_inputs/allele_independent_estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ include(joinpath(TESTDIR, "tmle_inputs", "test_utils.jl"))
tmle_inputs(parsed_args)
# Check dataset
trait_data = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
@test names(trait_data) == [
@test sort(names(trait_data)) == sort([
"SAMPLE_ID", "BINARY_1", "BINARY_2", "CONTINUOUS_1", "CONTINUOUS_2",
"COV_1", "21003", "22001", "TREAT_1", "PC1", "PC2", "RSID_2", "RSID_102",
"RSID_17", "RSID_198", "RSID_99"]
"RSID_17", "RSID_198", "RSID_99"])
@test size(trait_data) == (490, 16)
# Check estimands
tf_estimands = Dict(:TF1 => [], :TF2 => [])
Expand Down Expand Up @@ -71,10 +71,10 @@ end
tmle_inputs(parsed_args)
# Check dataset
trait_data = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
@test names(trait_data) == [
@test sort(names(trait_data)) == sort([
"SAMPLE_ID", "BINARY_1", "BINARY_2", "CONTINUOUS_1", "CONTINUOUS_2",
"COV_1", "21003", "22001", "TREAT_1", "PC1", "PC2", "RSID_2", "RSID_102",
"RSID_17", "RSID_198", "RSID_99"]
"RSID_17", "RSID_198", "RSID_99"])
@test size(trait_data) == (490, 16)
# Check estimands
tf_estimands = Dict(:TF1 => [], :TF2 => [])
Expand Down
12 changes: 6 additions & 6 deletions test/tmle_inputs/from_actors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ end

## Dataset file
trait_data = DataFrame(Arrow.Table("final.data.arrow"))
@test names(trait_data) == [
@test sort(names(trait_data)) == sort([
"SAMPLE_ID", "BINARY_1", "BINARY_2", "CONTINUOUS_1", "CONTINUOUS_2",
"COV_1", "21003", "22001", "TREAT_1", "PC1", "PC2", "RSID_2", "RSID_102",
"RSID_17", "RSID_198", "RSID_99"]
"RSID_17", "RSID_198", "RSID_99"])
@test size(trait_data) == (490, 16)

## Output estimands:
Expand Down Expand Up @@ -366,9 +366,9 @@ end

## Dataset file
traits = DataFrame(Arrow.Table("final.data.arrow"))
@test names(traits) == [
@test sort(names(traits)) == sort([
"SAMPLE_ID", "BINARY_1", "BINARY_2", "COV_1", "21003", "22001",
"PC1", "PC2", "RSID_2", "RSID_102", "RSID_17", "RSID_198", "RSID_99"]
"PC1", "PC2", "RSID_2", "RSID_102", "RSID_17", "RSID_198", "RSID_99"])
@test size(traits) == (490, 13)

# Parameter files:
Expand Down Expand Up @@ -449,10 +449,10 @@ end

## Dataset file
trait_data = DataFrame(Arrow.Table("final.data.arrow"))
@test names(trait_data) == [
@test sort(names(trait_data)) == sort([
"SAMPLE_ID", "BINARY_1", "BINARY_2", "CONTINUOUS_1", "CONTINUOUS_2",
"COV_1", "21003", "22001", "TREAT_1", "PC1", "PC2", "RSID_2", "RSID_102",
"RSID_17", "RSID_198", "RSID_99"]
"RSID_17", "RSID_198", "RSID_99"])
@test size(trait_data) == (490, 16)

## Parameter file:
Expand Down
2 changes: 1 addition & 1 deletion test/tmle_inputs/tmle_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end

# With missing variants -> throw ArgumentError
variants = Set(["TOTO"])
@test_throws TargeneCore.NotAllVariantsFoundError([], variants) TargeneCore.call_genotypes(bgen_dir, variants, 0.95;)
@test_throws TargeneCore.NotAllVariantsFoundError(variants) TargeneCore.call_genotypes(bgen_dir, variants, 0.95;)
end


Expand Down

0 comments on commit 8fb9c43

Please sign in to comment.