Skip to content

Commit

Permalink
update pipelines for new burden directory structure
Browse files Browse the repository at this point in the history
  • Loading branch information
bfclarke committed May 22, 2024
1 parent 669140b commit 7d13b7b
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 181 deletions.
199 changes: 80 additions & 119 deletions deeprvat/deeprvat/associate.py

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ def cli():
@click.option("--baseline-results", type=click.Path(exists=True), multiple=True)
@click.option("--baseline-results-out", type=click.Path())
@click.option("--seed-genes-out", type=click.Path())
@click.option("--regenie-options", type=str, multiple=True)
# @click.option("--regenie-options", type=str, multiple=True)
@click.argument("old_config_file", type=click.Path(exists=True))
@click.argument("new_config_file", type=click.Path())
def update_config(
association_only: bool,
phenotype: Optional[str],
baseline_results: Tuple[str],
baseline_results_out: Optional[str],
regenie_options: Optional[Tuple[str]],
# regenie_options: Optional[Tuple[str]],
seed_genes_out: Optional[str],
old_config_file: str,
new_config_file: str,
Expand Down Expand Up @@ -73,15 +73,15 @@ def update_config(
with open(old_config_file) as f:
config = yaml.safe_load(f)

if regenie_options is not None:
try:
existing_regenie_options = config["regenie"]["step_2"]["options"]
except KeyError:
existing_regenie_options = []
# if regenie_options is not None:
# try:
# existing_regenie_options = config["regenie"]["step_2"]["options"]
# except KeyError:
# existing_regenie_options = []

config["regenie"] = config.get("regenie", {})
config["regenie"]["step2"] = config["regenie"].get("step_2", {})
config["regenie"]["step_2"]["options"] = existing_regenie_options + list(regenie_options)
# config["regenie"] = config.get("regenie", {})
# config["regenie"]["step2"] = config["regenie"].get("step_2", {})
# config["regenie"]["step_2"]["options"] = existing_regenie_options + list(regenie_options)


if phenotype is not None:
Expand Down
22 changes: 19 additions & 3 deletions pipelines/association_testing/association_dataset.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,27 @@ rule association_dataset:
input:
config = '{phenotype}/deeprvat/hpopt_config.yaml'
output:
'{phenotype}/deeprvat/association_dataset.pkl'
temp('{phenotype}/deeprvat/association_dataset.pkl')
threads: 4
resources:
mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1),
load = 64000
mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1)
priority: 30
shell:
'deeprvat_associate make-dataset '
+ debug +
"--skip-genotypes "
'{input.config} '
'{output}'


rule association_dataset_burdens:
input:
config = f'{phenotypes[0]}/deeprvat/hpopt_config.yaml'
output:
temp('burdens/association_dataset.pkl')
threads: 4
resources:
mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1)
priority: 30
shell:
'deeprvat_associate make-dataset '
Expand Down
38 changes: 15 additions & 23 deletions pipelines/association_testing/burdens.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2
n_avg_chunks = config.get('n_avg_chunks', 40)
n_bags = config['training']['n_bags'] if not debug_flag else 3
n_repeats = config['n_repeats']
model_path = Path(config.get("pretrained_model_path", "pretrained_models"))
model_path = Path("models")

if not "cv_exp" in globals():
cv_exp = config.get("cv_exp", False)
Expand All @@ -20,25 +20,20 @@ config_file_prefix = (
)



rule average_burdens:
input:
chunks = [
(f'{p}/deeprvat/burdens/chunk{c}.' +
("finished" if p == phenotypes[0] else "linked"))
for p in phenotypes
for c in range(n_burden_chunks)
] if not cv_exp else '{phenotype}/deeprvat/burdens/merging.finished'
f'burdens/chunk{chunk}.finished' for chunk in range(n_burden_chunks)
] if not cv_exp else 'burdens/merging.finished'
output:
'{phenotype}/deeprvat/burdens/logs/burdens_averaging_{chunk}.finished',
temp('burdens/burdens_averaging_{chunk}.finished'),
params:
burdens_in = '{phenotype}/deeprvat/burdens/burdens.zarr',
burdens_out = '{phenotype}/deeprvat/burdens/burdens_average.zarr',
burdens_in = 'burdens/burdens.zarr',
burdens_out = 'burdens/burdens_average.zarr',
repeats = lambda wildcards: ''.join([f'--repeats {r} ' for r in range(int(n_repeats))])
threads: 1
resources:
mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098,
load = 4000,
mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098
priority: 10,
shell:
' && '.join([
Expand Down Expand Up @@ -71,8 +66,7 @@ rule compute_xy:
prefix = '.'
threads: 8
resources:
mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098,
load = lambda wildcards, attempt: 16000 + (attempt - 1) * 4000
mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098
shell:
' && '.join([
('deeprvat_associate compute-xy '
Expand All @@ -92,29 +86,28 @@ rule compute_burdens:
f'{model_path}/repeat_{repeat}/best/bag_{bag}.ckpt'
for repeat in range(n_repeats) for bag in range(n_bags)
],
dataset = '{phenotype}/deeprvat/association_dataset.pkl',
data_config = '{phenotype}/deeprvat/hpopt_config.yaml',
dataset = 'burdens/association_dataset.pkl',
data_config = f'{phenotypes[0]}/deeprvat/hpopt_config.yaml',
model_config = model_path / 'config.yaml',
output:
'{phenotype}/deeprvat/burdens/chunk{chunk}.finished'
temp('burdens/chunk{chunk}.finished')
params:
prefix = '.'
threads: 8
resources:
mem_mb = 2000000, # Using this value will tell our modified lsf.profile not to set a memory resource
load = 8000,
mem_mb = 32000,
gpus = 1
shell:
' && '.join([
('deeprvat_associate compute-burdens '
+ debug +
' --n-chunks '+ str(n_burden_chunks) + ' '
' --n-chunks ' + str(n_burden_chunks) + ' '
'--chunk {wildcards.chunk} '
'--dataset-file {input.dataset} '
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/{wildcards.phenotype}/deeprvat/burdens'),
'burdens'),
'touch {output}'
])

Expand All @@ -128,8 +121,7 @@ rule reverse_models:
model_path / "reverse_finished.tmp"
threads: 4
resources:
mem_mb = 20480,
load = 20480
mem_mb = 20480
shell:
" && ".join([
("deeprvat_associate reverse-models "
Expand Down
36 changes: 17 additions & 19 deletions pipelines/association_testing/regress_eval.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ config_file_prefix = (
"cv_split0/deeprvat/" if cv_exp else ""
)
########### Average regression
rule all_evaluate:
input:
expand("{phenotype}/deeprvat/eval/significant.parquet",
phenotype=phenotypes),
expand("{phenotype}/deeprvat/eval/all_results.parquet",
phenotype=phenotypes),
# rule all_evaluate:
# input:
# expand("{phenotype}/deeprvat/eval/significant.parquet",
# phenotype=phenotypes),
# expand("{phenotype}/deeprvat/eval/all_results.parquet",
# phenotype=phenotypes),

rule evaluate:
input:
Expand Down Expand Up @@ -68,15 +68,12 @@ rule combine_regression_chunks:
rule regress:
input:
config = f"{config_file_prefix}{{phenotype}}/deeprvat/hpopt_config.yaml",
chunks = lambda wildcards: (
[] if wildcards.phenotype == phenotypes[0]
else expand('{{phenotype}}/deeprvat/burdens/chunk{chunk}.linked',
chunk=range(n_burden_chunks))
) if not cv_exp else '{phenotype}/deeprvat/burdens/merging.finished',
phenotype_0_chunks = expand(
phenotypes[0] + '/deeprvat/burdens/logs/burdens_averaging_{chunk}.finished',
chunks = expand(
'burdens/burdens_averaging_{chunk}.finished',
chunk=range(n_avg_chunks)
),
x = '{phenotype}/deeprvat/xy/x.zarr',
y = '{phenotype}/deeprvat/xy/y.zarr',
output:
temp('{phenotype}/deeprvat/average_regression_results/burden_associations_{chunk}.parquet'),
threads: 2
Expand All @@ -85,19 +82,20 @@ rule regress:
# mem_mb = 16000,
load = lambda wildcards, attempt: 28000 + (attempt - 1) * 4000
params:
burden_file = f'{phenotypes[0]}/deeprvat/burdens/burdens_average.zarr',
burden_dir = '{phenotype}/deeprvat/burdens',
burden_file = 'burdens/burdens_average.zarr',
xy_dir = "{phenotype}/deeprvat/xy",
# burden_dir = 'burdens',
out_dir = '{phenotype}/deeprvat/average_regression_results'
shell:
'deeprvat_associate regress '
+ debug +
'--chunk {wildcards.chunk} '
'--n-chunks ' + str(n_regression_chunks) + ' '
'--use-bias '
'--repeat 0 '
'--burden-file {params.burden_file} '
# '--repeat 0 '
+ do_scoretest +
'{input.config} '
'{params.burden_dir} ' #TODO make this w/o repeats
"{params.xy_dir} "
# '{params.burden_dir} ' #TODO make this w/o repeats
"{params.burden_file} "
'{params.out_dir}'

4 changes: 1 addition & 3 deletions pipelines/association_testing/regress_eval_regenie.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes

n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2

burdens = Path(config["burden_file"])

regenie_config_step1 = config["regenie"]["step_1"]
regenie_config_step2 = config["regenie"]["step_2"]
regenie_step1_bsize = regenie_config_step1["bsize"]
Expand Down Expand Up @@ -230,7 +228,7 @@ rule make_regenie_burdens:
input:
gene_file = config["data"]["dataset_config"]["rare_embedding"]["config"]["gene_file"],
gtf_file = config["gtf_file"],
burdens = burdens,
burdens = 'burdens/burdens_average.zarr',
genes = burdens.parent / "genes.npy",
samples = burdens.parent / "sample_ids.zarr",
datasets = expand("{phenotype}/deeprvat/association_dataset.pkl",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ n_bags = config['training']['n_bags'] if not debug_flag else 3
n_repeats = config['n_repeats']
debug = '--debug ' if debug_flag else ''
do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else ''
model_path = Path(config.get("pretrained_model_path", "pretrained_models"))
model_path = Path("models")

wildcard_constraints:
repeat="\d+",
Expand All @@ -33,4 +33,5 @@ rule all:
rule all_association_dataset:
input:
expand('{phenotype}/deeprvat/association_dataset.pkl',
phenotype=phenotypes)
phenotype=phenotypes),
'association_dataset_burdens.pkl',
2 changes: 1 addition & 1 deletion pipelines/association_testing_pretrained.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ n_repeats = config['n_repeats']
debug = '--debug ' if debug_flag else ''
do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else ''
tensor_compression_level = config['training'].get('tensor_compression_level', 1)
model_path = Path(config.get("pretrained_model_path", "pretrained_models"))
model_path = Path("models")

wildcard_constraints:
repeat="\d+",
Expand Down
2 changes: 1 addition & 1 deletion pipelines/association_testing_pretrained_regenie.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ n_bags = config['training']['n_bags'] if not debug_flag else 3
n_repeats = config['n_repeats']
debug = '--debug ' if debug_flag else ''
do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else ''
model_path = Path(config.get("pretrained_model_path", "pretrained_models"))
model_path = Path("models")

wildcard_constraints:
repeat="\d+",
Expand Down

0 comments on commit 7d13b7b

Please sign in to comment.