diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index 5b744096..d2369219 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -19,7 +19,7 @@ import statsmodels.api as sm import yaml from bgen import BgenWriter -from numcodecs import Blosc +from numcodecs import Blosc, JSON from seak import scoretest from statsmodels.tools.tools import add_constant from torch.utils.data import DataLoader, Dataset, Subset @@ -53,8 +53,7 @@ def get_burden( batch: Dict, agg_models: Dict[str, List[nn.Module]], device: torch.device = torch.device("cpu"), - skip_burdens=False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute burden scores for rare variants. @@ -76,24 +75,17 @@ def get_burden( with torch.no_grad(): X = batch["rare_variant_annotations"].to(device) burden = [] - if not skip_burdens: - for key in sorted( - list(agg_models.keys()), key=lambda x: int(x.split("_")[1]) - ): - this_agg_models = agg_models[key] - this_burden: torch.Tensor = sum([m(X) for m in this_agg_models]) / len( - this_agg_models - ) - burden.append(this_burden.cpu().numpy()) - burden = np.concatenate(burden, axis=2) - else: - burden = None + for key in sorted(list(agg_models.keys()), key=lambda x: int(x.split("_")[1])): + this_agg_models = agg_models[key] + this_burden: torch.Tensor = sum([m(X) for m in this_agg_models]) / len( + this_agg_models + ) + burden.append(this_burden.cpu().numpy()) + burden = np.concatenate(burden, axis=2) - y = batch["y"] - x = batch["x_phenotypes"] sample_ids = batch["sample"] - return burden, y, x, sample_ids + return burden, sample_ids def separate_parallel_results(results: List) -> Tuple[List, ...]: @@ -117,7 +109,7 @@ def make_dataset_( config: Dict, debug: bool = False, data_key: str = "data", - skip_burdens: bool = False, + skip_genotypes: bool = False, samples: Optional[List[int]] = None, ) -> Dataset: """ @@ -129,6 +121,8 @@ def make_dataset_( :type debug: bool :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". :type data_key: str + :param skip_genotypes: Retrieve only covariates and phenotypes, not genotypes + :type skip_genotypes: bool :param samples: List of sample indices to include in the dataset, defaults to None. :type samples: List[int] :return: Loaded instance of the created dataset. @@ -144,7 +138,7 @@ def make_dataset_( else: gt_variant_args = ( (data_config["gt_file"], data_config["variant_file"]) - if not skip_burdens + if not skip_genotypes else tuple() ) ds = DenseGTDataset( @@ -169,11 +163,11 @@ def make_dataset_( @cli.command() @click.option("--debug", is_flag=True) @click.option("--data-key", type=str, default="data") -@click.option("--skip-burdens", is_flag=True) +@click.option("--skip-genotypes", is_flag=True) @click.argument("config-file", type=click.Path(exists=True)) @click.argument("out-file", type=click.Path()) def make_dataset( - debug: bool, data_key: str, skip_burdens: bool, config_file: str, out_file: str + debug: bool, data_key: str, skip_genotypes: bool, config_file: str, out_file: str ): """ Create a dataset based on the provided configuration and save to a pickle file. @@ -182,6 +176,8 @@ def make_dataset( :type debug: bool :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". :type data_key: str + :param skip_genotypes: Retrieve only covariates and phenotypes, not genotypes + :type skip_genotypes: bool :param config_file: Path to the configuration file. 
:type config_file: str :param out_file: Path to the output file. @@ -192,7 +188,7 @@ def make_dataset( config = yaml.safe_load(f) ds = make_dataset_( - config, debug=debug, data_key=data_key, skip_burdens=skip_burdens + config, debug=debug, data_key=data_key, skip_genotypes=skip_genotypes ) with open(out_file, "wb") as f: @@ -279,7 +275,7 @@ def compute_xy( with open(dataset_file, "rb") as f: dataset = pickle.load(f) else: - dataset = make_dataset_(data_config) + dataset = make_dataset_(data_config, skip_genotypes=True) sample_ids, x, y = compute_xy_( data_config, @@ -303,10 +299,7 @@ def compute_burdens_( device: torch.device = torch.device("cpu"), bottleneck: bool = False, compression_level: int = 1, - skip_burdens: bool = False, -) -> Tuple[ - np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array, zarr.core.Array -]: +) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array]: """ Compute burdens using the PyTorch model for each repeat. @@ -339,21 +332,17 @@ def compute_burdens_( .. note:: Checkpoint models all corresponding to the same repeat are averaged for that repeat. """ - if not skip_burdens: - logger.info("agg_models[*][*].reverse:") - pprint( - { - repeat: [m.reverse for m in models] - for repeat, models in agg_models.items() - } - ) + logger.info("agg_models[*][*].reverse:") + pprint( + {repeat: [m.reverse for m in models] for repeat, models in agg_models.items()} + ) data_config = config["data"] ds_full = ds.dataset if isinstance(ds, Subset) else ds collate_fn = getattr(ds_full, "collate_fn", None) n_total_samples = len(ds) - ds.rare_embedding.skip_embedding = skip_burdens + ds.rare_embedding.skip_embedding = False if chunk is not None: if n_chunks is None: @@ -390,63 +379,37 @@ def compute_burdens_( file=sys.stdout, total=(n_samples // batch_size + (n_samples % batch_size != 0)), ): - this_burdens, this_y, this_x, this_sampleid = get_burden( - batch, agg_models, device=device, skip_burdens=skip_burdens - ) + this_burdens, this_sampleid = get_burden(batch, agg_models, device=device) if i == 0: - if not skip_burdens: - chunk_burden = np.zeros(shape=(n_samples,) + this_burdens.shape[1:]) - chunk_y = np.zeros(shape=(n_samples,) + this_y.shape[1:]) - chunk_x = np.zeros(shape=(n_samples,) + this_x.shape[1:]) - chunk_sampleid = np.zeros(shape=(n_samples)) + chunk_burden = np.zeros(shape=(n_samples,) + this_burdens.shape[1:]) + chunk_sampleid = [""] * n_samples logger.info(f"Batch size: {batch['rare_variant_annotations'].shape}") - if not skip_burdens: - burdens = zarr.open( - Path(cache_dir) / "burdens.zarr", - mode="a", - shape=(n_total_samples,) + this_burdens.shape[1:], - chunks=(1000, 1000, 1), - dtype=np.float32, - compressor=Blosc(clevel=compression_level), - ) - logger.info(f"burdens shape: {burdens.shape}") - else: - burdens = None - - y = zarr.open( - Path(cache_dir) / "y.zarr", - mode="a", - shape=(n_total_samples,) + this_y.shape[1:], - chunks=(None, None), - dtype=np.float32, - compressor=Blosc(clevel=compression_level), - ) - x = zarr.open( - Path(cache_dir) / "x.zarr", + burdens = zarr.open( + Path(cache_dir) / "burdens.zarr", mode="a", - shape=(n_total_samples,) + this_x.shape[1:], - chunks=(None, None), + shape=(n_total_samples,) + this_burdens.shape[1:], + chunks=(1000, 1000, 1), dtype=np.float32, compressor=Blosc(clevel=compression_level), ) + logger.info(f"burdens shape: {burdens.shape}") + sample_ids = zarr.open( Path(cache_dir) / "sample_ids.zarr", mode="a", shape=(n_total_samples), chunks=(None), - dtype=np.float32, - 
compressor=Blosc(clevel=compression_level), + dtype=str, + compressor=JSON(), ) + start_idx = i * batch_size end_idx = min(start_idx + batch_size, chunk_end) # read from chunk shape - if not skip_burdens: - chunk_burden[start_idx:end_idx] = this_burdens + chunk_burden[start_idx:end_idx] = this_burdens - chunk_y[start_idx:end_idx] = this_y - chunk_x[start_idx:end_idx] = this_x chunk_sampleid[start_idx:end_idx] = this_sampleid if debug: @@ -457,11 +420,7 @@ def compute_burdens_( if bottleneck and i > 20: break - if not skip_burdens: - burdens[chunk_start:chunk_end] = chunk_burden - - y[chunk_start:chunk_end] = chunk_y - x[chunk_start:chunk_end] = chunk_x + burdens[chunk_start:chunk_end] = chunk_burden sample_ids[chunk_start:chunk_end] = chunk_sampleid if torch.cuda.is_available(): @@ -469,7 +428,7 @@ def compute_burdens_( "Max GPU memory allocated: " f"{torch.cuda.max_memory_allocated(0)} bytes" ) - return ds_full.rare_embedding.genes, burdens, y, x, sample_ids + return ds_full.rare_embedding.genes, burdens, sample_ids def make_regenie_input_( @@ -965,7 +924,6 @@ def load_models( @click.option("--n-chunks", type=int) @click.option("--chunk", type=int) @click.option("--dataset-file", type=click.Path(exists=True)) -@click.option("--link-burdens", type=click.Path()) @click.argument("data-config-file", type=click.Path(exists=True)) @click.argument("model-config-file", type=click.Path(exists=True)) @click.argument("checkpoint-files", type=click.Path(exists=True), nargs=-1) @@ -976,7 +934,6 @@ def compute_burdens( n_chunks: Optional[int], chunk: Optional[int], dataset_file: Optional[str], - link_burdens: Optional[str], data_config_file: str, model_config_file: str, checkpoint_files: Tuple[str], @@ -995,8 +952,6 @@ def compute_burdens( :type chunk: Optional[int] :param dataset_file: Path to the dataset file, i.e., association_dataset.pkl. :type dataset_file: Optional[str] - :param link_burdens: Path to burden.zarr file to link. - :type link_burdens: Optional[str] :param data_config_file: Path to the data configuration file. :type data_config_file: str :param model_config_file: Path to the model configuration file. 
@@ -1025,7 +980,7 @@ def compute_burdens( with open(dataset_file, "rb") as f: dataset = pickle.load(f) else: - dataset = make_dataset_(config) + dataset = make_dataset_(data_config) if torch.cuda.is_available(): logger.info("Using GPU") @@ -1034,12 +989,9 @@ def compute_burdens( logger.info("Using CPU") device = torch.device("cpu") - if link_burdens is None: - agg_models = load_models(model_config, checkpoint_files, device=device) - else: - agg_models = None + agg_models = load_models(model_config, checkpoint_files, device=device) - genes, _, _, _, _ = compute_burdens_( + genes, _, _ = compute_burdens_( debug, data_config, dataset, @@ -1049,15 +1001,10 @@ def compute_burdens( chunk=chunk, device=device, bottleneck=bottleneck, - skip_burdens=(link_burdens is not None), ) logger.info("Saving computed burdens, corresponding genes, and targets") np.save(Path(out_dir) / "genes.npy", genes) - if link_burdens is not None: - source_path = Path(out_dir) / "burdens.zarr" - source_path.unlink(missing_ok=True) - source_path.symlink_to(link_burdens) def regress_on_gene_scoretest( @@ -1261,12 +1208,13 @@ def regress_( @click.option("--n-chunks", type=int, default=1) @click.option("--use-bias", is_flag=True) @click.option("--gene-file", type=click.Path(exists=True)) -@click.option("--repeat", type=int, default=0) +# @click.option("--repeat", type=int, default=0) @click.option("--do-scoretest", is_flag=True) @click.option("--sample-file", type=click.Path(exists=True)) -@click.option("--burden-file", type=click.Path(exists=True)) @click.argument("config-file", type=click.Path(exists=True)) -@click.argument("burden-dir", type=click.Path(exists=True)) +@click.argument("xy-dir", type=click.Path(exists=True)) +# @click.argument("burden-dir", type=click.Path(exists=True)) +@click.argument("burden-file", type=click.Path(exists=True)) @click.argument("out-dir", type=click.Path()) def regress( debug: bool, @@ -1274,13 +1222,15 @@ def regress( n_chunks: int, use_bias: bool, gene_file: str, - repeat: int, + # repeat: int, config_file: str, - burden_dir: str, + xy_dir: str, + # burden_dir: str, + burden_file: str, out_dir: str, do_scoretest: bool, sample_file: Optional[str], - burden_file: Optional[str], + # burden_file: Optional[str], ): """ Perform regression analysis. @@ -1295,8 +1245,8 @@ def regress( :type use_bias: bool :param gene_file: Path to the gene file. :type gene_file: str - :param repeat: Index of the repeat, defaults to 0. - :type repeat: int + # :param repeat: Index of the repeat, defaults to 0. + # :type repeat: int :param config_file: Path to the configuration file. :type config_file: str :param burden_dir: Path to the directory containing burdens.zarr file. 
@@ -1309,16 +1259,24 @@ def regress( :type sample_file: Optional[str] :return: Regression results saved to out_dir as "burden_associations_{chunk}.parquet" """ - logger.info("Loading saved burdens") + burden_dir = Path(burden_file).parent # if burden_file is not None: # logger.info(f'Loading burdens from {burden_file}') # burdens = zarr.open(burden_file)[:, :, repeat] # else: # burdens = zarr.open(Path(burden_dir) / "burdens.zarr")[:, :, repeat] - logger.info(f"Loading x, y, genes from {burden_dir}") - y = zarr.open(Path(burden_dir) / "y.zarr")[:] - x_pheno = zarr.open(Path(burden_dir) / "x.zarr")[:] - genes = pd.Series(np.load(Path(burden_dir) / "genes.npy")) + logger.info(f"Loading covariates and targets from {xy_dir}") + y = zarr.load(Path(xy_dir) / "y.zarr") + x_pheno = zarr.load(Path(xy_dir) / "x.zarr") + + # Make sure sample IDs agree + try: + assert np.array_equal( + zarr.load(Path(xy_dir) / "sample_ids.zarr"), + zarr.load(Path(burden_dir) / "sample_ids.zarr"), + ) + except: + raise ValueError("Sample IDs in xy_dir and burden_dir disagree") if sample_file is not None: with open(sample_file, "rb") as f: @@ -1330,8 +1288,13 @@ def regress( x_pheno = x_pheno[samples] n_samples = y.shape[0] - assert y.shape[0] == n_samples - assert x_pheno.shape[0] == n_samples + try: + assert y.shape[0] == n_samples + assert x_pheno.shape[0] == n_samples + except: + raise ValueError( + "Inconsistent number of samples between covariates and targets" + ) # assert len(genes) == burdens.shape[1] nan_mask = ~np.isnan(y).squeeze() @@ -1348,6 +1311,9 @@ def regress( gene_df.set_index("id") genes = gene_df.loc[genes, "gene"].str.split(".").apply(lambda x: x[0]) + logger.info(f"Loading saved burdens from {burden_dir}") + genes = pd.Series(np.load(Path(burden_dir) / "genes.npy")) + chunk_size = math.ceil(len(genes) / n_chunks) chunk_start = chunk * chunk_size chunk_end = min(len(genes), chunk_start + chunk_size) @@ -1357,14 +1323,9 @@ def regress( genes = genes.iloc[chunk_start:chunk_end] gene_indices = np.arange(len(genes)) + logger.info(f"Only extracting genes in range {chunk_start, chunk_end}") - if burden_file is not None: - logger.info(f"Loading burdens from {burden_file}") - burdens = zarr.open(burden_file)[:, chunk_start:chunk_end, repeat] - else: - burdens = zarr.open(Path(burden_dir) / "burdens.zarr")[ - :, chunk_start:chunk_end, repeat - ] + burdens = zarr.open(burden_file)[:, chunk_start:chunk_end, 0] if sample_file is not None: burdens = burdens[samples] diff --git a/deeprvat/deeprvat/config.py b/deeprvat/deeprvat/config.py index 7f594855..0d073300 100644 --- a/deeprvat/deeprvat/config.py +++ b/deeprvat/deeprvat/config.py @@ -30,7 +30,7 @@ def cli(): @click.option("--baseline-results", type=click.Path(exists=True), multiple=True) @click.option("--baseline-results-out", type=click.Path()) @click.option("--seed-genes-out", type=click.Path()) -@click.option("--regenie-options", type=str, multiple=True) +# @click.option("--regenie-options", type=str, multiple=True) @click.argument("old_config_file", type=click.Path(exists=True)) @click.argument("new_config_file", type=click.Path()) def update_config( @@ -38,7 +38,7 @@ def update_config( phenotype: Optional[str], baseline_results: Tuple[str], baseline_results_out: Optional[str], - regenie_options: Optional[Tuple[str]], + # regenie_options: Optional[Tuple[str]], seed_genes_out: Optional[str], old_config_file: str, new_config_file: str, @@ -73,15 +73,15 @@ def update_config( with open(old_config_file) as f: config = yaml.safe_load(f) - if regenie_options is 
not None: - try: - existing_regenie_options = config["regenie"]["step_2"]["options"] - except KeyError: - existing_regenie_options = [] + # if regenie_options is not None: + # try: + # existing_regenie_options = config["regenie"]["step_2"]["options"] + # except KeyError: + # existing_regenie_options = [] - config["regenie"] = config.get("regenie", {}) - config["regenie"]["step2"] = config["regenie"].get("step_2", {}) - config["regenie"]["step_2"]["options"] = existing_regenie_options + list(regenie_options) + # config["regenie"] = config.get("regenie", {}) + # config["regenie"]["step2"] = config["regenie"].get("step_2", {}) + # config["regenie"]["step_2"]["options"] = existing_regenie_options + list(regenie_options) if phenotype is not None: diff --git a/pipelines/association_testing/association_dataset.snakefile b/pipelines/association_testing/association_dataset.snakefile index add59e6f..661d929e 100644 --- a/pipelines/association_testing/association_dataset.snakefile +++ b/pipelines/association_testing/association_dataset.snakefile @@ -11,11 +11,27 @@ rule association_dataset: input: config = '{phenotype}/deeprvat/hpopt_config.yaml' output: - '{phenotype}/deeprvat/association_dataset.pkl' + temp('{phenotype}/deeprvat/association_dataset.pkl') threads: 4 resources: - mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1), - load = 64000 + mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1) + priority: 30 + shell: + 'deeprvat_associate make-dataset ' + + debug + + "--skip-genotypes " + '{input.config} ' + '{output}' + + +rule association_dataset_burdens: + input: + config = f'{phenotypes[0]}/deeprvat/hpopt_config.yaml' + output: + temp('burdens/association_dataset.pkl') + threads: 4 + resources: + mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1) priority: 30 shell: 'deeprvat_associate make-dataset ' diff --git a/pipelines/association_testing/burdens.snakefile b/pipelines/association_testing/burdens.snakefile index 5d901e3c..7c332f2e 100644 --- a/pipelines/association_testing/burdens.snakefile +++ b/pipelines/association_testing/burdens.snakefile @@ -10,7 +10,7 @@ n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 n_avg_chunks = config.get('n_avg_chunks', 40) n_bags = config['training']['n_bags'] if not debug_flag else 3 n_repeats = config['n_repeats'] -model_path = Path(config.get("pretrained_model_path", "pretrained_models")) +model_path = Path("models") if not "cv_exp" in globals(): cv_exp = config.get("cv_exp", False) @@ -20,25 +20,20 @@ config_file_prefix = ( ) - rule average_burdens: input: chunks = [ - (f'{p}/deeprvat/burdens/chunk{c}.' 
+ - ("finished" if p == phenotypes[0] else "linked")) - for p in phenotypes - for c in range(n_burden_chunks) - ] if not cv_exp else '{phenotype}/deeprvat/burdens/merging.finished' + f'burdens/chunk{chunk}.finished' for chunk in range(n_burden_chunks) + ] if not cv_exp else 'burdens/merging.finished' output: - '{phenotype}/deeprvat/burdens/logs/burdens_averaging_{chunk}.finished', + temp('burdens/burdens_averaging_{chunk}.finished'), params: - burdens_in = '{phenotype}/deeprvat/burdens/burdens.zarr', - burdens_out = '{phenotype}/deeprvat/burdens/burdens_average.zarr', + burdens_in = 'burdens/burdens.zarr', + burdens_out = 'burdens/burdens_average.zarr', repeats = lambda wildcards: ''.join([f'--repeats {r} ' for r in range(int(n_repeats))]) threads: 1 resources: - mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098, - load = 4000, + mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098 priority: 10, shell: ' && '.join([ @@ -71,8 +66,7 @@ rule compute_xy: prefix = '.' threads: 8 resources: - mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098, - load = lambda wildcards, attempt: 16000 + (attempt - 1) * 4000 + mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098 shell: ' && '.join([ ('deeprvat_associate compute-xy ' @@ -92,29 +86,28 @@ rule compute_burdens: f'{model_path}/repeat_{repeat}/best/bag_{bag}.ckpt' for repeat in range(n_repeats) for bag in range(n_bags) ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', + dataset = 'burdens/association_dataset.pkl', + data_config = f'{phenotypes[0]}/deeprvat/hpopt_config.yaml', model_config = model_path / 'config.yaml', output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' + temp('burdens/chunk{chunk}.finished') params: prefix = '.' 
threads: 8 resources: - mem_mb = 2000000, # Using this value will tell our modified lsf.profile not to set a memory resource - load = 8000, + mem_mb = 32000, gpus = 1 shell: ' && '.join([ ('deeprvat_associate compute-burdens ' + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' + ' --n-chunks ' + str(n_burden_chunks) + ' ' '--chunk {wildcards.chunk} ' '--dataset-file {input.dataset} ' '{input.data_config} ' '{input.model_config} ' '{input.checkpoints} ' - '{params.prefix}/{wildcards.phenotype}/deeprvat/burdens'), + 'burdens'), 'touch {output}' ]) @@ -128,8 +121,7 @@ rule reverse_models: model_path / "reverse_finished.tmp" threads: 4 resources: - mem_mb = 20480, - load = 20480 + mem_mb = 20480 shell: " && ".join([ ("deeprvat_associate reverse-models " diff --git a/pipelines/association_testing/regress_eval.snakefile b/pipelines/association_testing/regress_eval.snakefile index 55e36a53..5b8c0b3f 100644 --- a/pipelines/association_testing/regress_eval.snakefile +++ b/pipelines/association_testing/regress_eval.snakefile @@ -17,12 +17,12 @@ config_file_prefix = ( "cv_split0/deeprvat/" if cv_exp else "" ) ########### Average regression -rule all_evaluate: - input: - expand("{phenotype}/deeprvat/eval/significant.parquet", - phenotype=phenotypes), - expand("{phenotype}/deeprvat/eval/all_results.parquet", - phenotype=phenotypes), +# rule all_evaluate: +# input: +# expand("{phenotype}/deeprvat/eval/significant.parquet", +# phenotype=phenotypes), +# expand("{phenotype}/deeprvat/eval/all_results.parquet", +# phenotype=phenotypes), rule evaluate: input: @@ -68,15 +68,12 @@ rule combine_regression_chunks: rule regress: input: config = f"{config_file_prefix}{{phenotype}}/deeprvat/hpopt_config.yaml", - chunks = lambda wildcards: ( - [] if wildcards.phenotype == phenotypes[0] - else expand('{{phenotype}}/deeprvat/burdens/chunk{chunk}.linked', - chunk=range(n_burden_chunks)) - ) if not cv_exp else '{phenotype}/deeprvat/burdens/merging.finished', - phenotype_0_chunks = expand( - phenotypes[0] + '/deeprvat/burdens/logs/burdens_averaging_{chunk}.finished', + chunks = expand( + 'burdens/burdens_averaging_{chunk}.finished', chunk=range(n_avg_chunks) ), + x = '{phenotype}/deeprvat/xy/x.zarr', + y = '{phenotype}/deeprvat/xy/y.zarr', output: temp('{phenotype}/deeprvat/average_regression_results/burden_associations_{chunk}.parquet'), threads: 2 @@ -85,8 +82,9 @@ rule regress: # mem_mb = 16000, load = lambda wildcards, attempt: 28000 + (attempt - 1) * 4000 params: - burden_file = f'{phenotypes[0]}/deeprvat/burdens/burdens_average.zarr', - burden_dir = '{phenotype}/deeprvat/burdens', + burden_file = 'burdens/burdens_average.zarr', + xy_dir = "{phenotype}/deeprvat/xy", + # burden_dir = 'burdens', out_dir = '{phenotype}/deeprvat/average_regression_results' shell: 'deeprvat_associate regress ' @@ -94,10 +92,10 @@ rule regress: '--chunk {wildcards.chunk} ' '--n-chunks ' + str(n_regression_chunks) + ' ' '--use-bias ' - '--repeat 0 ' - '--burden-file {params.burden_file} ' + # '--repeat 0 ' + do_scoretest + '{input.config} ' - '{params.burden_dir} ' #TODO make this w/o repeats + "{params.xy_dir} " + # '{params.burden_dir} ' #TODO make this w/o repeats + "{params.burden_file} " '{params.out_dir}' - diff --git a/pipelines/association_testing/regress_eval_regenie.snakefile b/pipelines/association_testing/regress_eval_regenie.snakefile index b35396b9..a389c69f 100644 --- a/pipelines/association_testing/regress_eval_regenie.snakefile +++ b/pipelines/association_testing/regress_eval_regenie.snakefile @@ -10,8 +10,6 @@ 
 phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes
 n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2
 
-burdens = Path(config["burden_file"])
-
 regenie_config_step1 = config["regenie"]["step_1"]
 regenie_config_step2 = config["regenie"]["step_2"]
 regenie_step1_bsize = regenie_config_step1["bsize"]
@@ -230,7 +228,7 @@ rule make_regenie_burdens:
     input:
         gene_file = config["data"]["dataset_config"]["rare_embedding"]["config"]["gene_file"],
        gtf_file = config["gtf_file"],
-        burdens = burdens,
-        genes = burdens.parent / "genes.npy",
-        samples = burdens.parent / "sample_ids.zarr",
+        burdens = 'burdens/burdens_average.zarr',
+        genes = 'burdens/genes.npy',
+        samples = 'burdens/sample_ids.zarr',
         datasets = expand("{phenotype}/deeprvat/association_dataset.pkl",
diff --git a/pipelines/association_testing_precomputed_burdens_regenie.snakefile b/pipelines/association_testing_precomputed_burdens_regenie.snakefile
index 29ee3ea7..9eac0f43 100644
--- a/pipelines/association_testing_precomputed_burdens_regenie.snakefile
+++ b/pipelines/association_testing_precomputed_burdens_regenie.snakefile
@@ -12,7 +12,7 @@ n_bags = config['training']['n_bags'] if not debug_flag else 3
 n_repeats = config['n_repeats']
 debug = '--debug ' if debug_flag else ''
 do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else ''
-model_path = Path(config.get("pretrained_model_path", "pretrained_models"))
+model_path = Path("models")
 
 wildcard_constraints:
     repeat="\d+",
@@ -33,4 +33,5 @@ rule all:
 rule all_association_dataset:
     input:
         expand('{phenotype}/deeprvat/association_dataset.pkl',
-               phenotype=phenotypes)
+               phenotype=phenotypes),
+        'burdens/association_dataset.pkl',
diff --git a/pipelines/association_testing_pretrained.snakefile b/pipelines/association_testing_pretrained.snakefile
index 7895d27b..9b308682 100644
--- a/pipelines/association_testing_pretrained.snakefile
+++ b/pipelines/association_testing_pretrained.snakefile
@@ -16,7 +16,7 @@ n_repeats = config['n_repeats']
 debug = '--debug ' if debug_flag else ''
 do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else ''
 tensor_compression_level = config['training'].get('tensor_compression_level', 1)
-model_path = Path(config.get("pretrained_model_path", "pretrained_models"))
+model_path = Path("models")
 
 wildcard_constraints:
     repeat="\d+",
diff --git a/pipelines/association_testing_pretrained_regenie.snakefile b/pipelines/association_testing_pretrained_regenie.snakefile
index f3eb0b0e..6cb45318 100644
--- a/pipelines/association_testing_pretrained_regenie.snakefile
+++ b/pipelines/association_testing_pretrained_regenie.snakefile
@@ -12,7 +12,7 @@ n_bags = config['training']['n_bags'] if not debug_flag else 3
 n_repeats = config['n_repeats']
 debug = '--debug ' if debug_flag else ''
 do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else ''
-model_path = Path(config.get("pretrained_model_path", "pretrained_models"))
+model_path = Path("models")
 
 wildcard_constraints:
     repeat="\d+",