From 755efbba9496aa6c015a56a499a8f8212871f85f Mon Sep 17 00:00:00 2001 From: Konstantin Hemker <33329141+konst-int-i@users.noreply.github.com> Date: Fri, 6 Dec 2024 07:18:39 -0500 Subject: [PATCH] Feature/hestbench flexibility (#77) * MODELS: CustomInferenceEncoder fix in factory * EVAL: sensible eval defaults for easier in-pipeline workflow * GIT: ignore .pyc * EVAL: prevent repeat extraction when overwrite=True --- .gitignore | 1 + bench_config/bench_config.yaml | 8 +- docs/source/api.md | 2 +- src/hest/bench/__init__.py | 2 +- src/hest/bench/benchmark.py | 179 ++++++++++-------- .../bench/cpath_model_zoo/inference_models.py | 2 +- tutorials/4-Running-HEST-Benchmark.ipynb | 12 +- 7 files changed, 119 insertions(+), 87 deletions(-) diff --git a/.gitignore b/.gitignore index a03b849..1aad718 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ tests/assets config tests/output_tests HEST/ +*.pyc results atlas diff --git a/bench_config/bench_config.yaml b/bench_config/bench_config.yaml index 42c07c5..3e75f06 100644 --- a/bench_config/bench_config.yaml +++ b/bench_config/bench_config.yaml @@ -1,14 +1,14 @@ # directory containing the data for each task -bench_data_root: 'bench_data' +bench_data_root: 'eval/bench_data' # directory where benchmark results will be dumped -results_dir: 'ST_pred_results' +results_dir: 'eval/ST_pred_results' # directory where the vision embeddings will be dumped -embed_dataroot: 'ST_data_emb' +embed_dataroot: 'eval/ST_data_emb' # directory to the model weights root -weights_root: 'fm_v1' +weights_root: 'eval/fm_v1' # inference parameters batch_size: 128 diff --git a/docs/source/api.md b/docs/source/api.md index 08a69ba..2816f24 100644 --- a/docs/source/api.md +++ b/docs/source/api.md @@ -24,7 +24,7 @@ .. autosummary:: :toctree: generated - benchmark_encoder + benchmark ``` ## HESTData class diff --git a/src/hest/bench/__init__.py b/src/hest/bench/__init__.py index a374b1c..8c1d799 100644 --- a/src/hest/bench/__init__.py +++ b/src/hest/bench/__init__.py @@ -1,2 +1,2 @@ from . 
import st_dataset -from .benchmark import benchmark_encoder \ No newline at end of file +from .benchmark import benchmark \ No newline at end of file diff --git a/src/hest/bench/benchmark.py b/src/hest/bench/benchmark.py index 0e8fc44..43b14c9 100644 --- a/src/hest/bench/benchmark.py +++ b/src/hest/bench/benchmark.py @@ -4,7 +4,10 @@ import json import os from operator import itemgetter -from typing import Callable, List +from typing import List, Optional, Tuple +from dataclasses import dataclass, asdict, field +from argparse import Namespace + import numpy as np import pandas as pd @@ -31,37 +34,66 @@ save_pkl) from hest.bench.utils.utils import merge_dict, get_current_time -# Generic training settings +# Generic training settings - note that defaults are set in BenchmarkConfig parser = argparse.ArgumentParser(description='Configurations for linear probing') ### optimizer settings ### -parser.add_argument('--seed', type=int, default=1, - help='random seed for reproducible experiment (default: 1)') -parser.add_argument('--overwrite', action='store_true', default=False, +parser.add_argument('--seed', type=int, + help='random seed for reproducible experiment') +parser.add_argument('--overwrite', action='store_true', help='overwrite existing results') parser.add_argument('--bench_data_root', type=str, help='root directory containing all the datasets') parser.add_argument('--embed_dataroot', type=str) parser.add_argument('--weights_root', type=str) -parser.add_argument('--private_weights_root', type=str, default=None) +parser.add_argument('--private_weights_root', type=str) parser.add_argument('--results_dir', type=str) -parser.add_argument('--exp_code', type=str, default=None) +parser.add_argument('--exp_code', type=str) ### specify encoder settings ### -parser.add_argument('--batch_size', type=int, default=128, help='Batch size') -parser.add_argument('--num_workers', type=int, default=1, help='Number of workers for dataloader') +parser.add_argument('--batch_size', type=int, help='Batch size') +parser.add_argument('--num_workers', type=int, help='Number of workers for dataloader') ### specify dataset settings ### -parser.add_argument('--gene_list', type=str, default='var_50genes.json') -parser.add_argument('--method', type=str, default='ridge') -parser.add_argument('--alpha', type=float, default=None) -parser.add_argument('--kfold', action='store_true', default=False) -parser.add_argument('--benchmark_encoders', action='store_true', default=False) -parser.add_argument('--normalize', type=bool, default=True) -parser.add_argument('--dimreduce', type=str, default=None, help='whenever to perform dimensionality reduction before linear probing, can be "PCA" or None') -parser.add_argument('--latent_dim', type=int, default=256, help='dimensionality reduction latent dimension') -parser.add_argument('--encoders', nargs='+', help='All the encoders to benchmark', default=[]) -parser.add_argument('--datasets', nargs='+', help='Datasets from bench_data_root to use during benchmark', default=['*']) -parser.add_argument('--config', type=str, help='Path to a benchmark config file, arguments provided in the config file will overwrite the command line args', default=None) +parser.add_argument('--gene_list', type=str) +parser.add_argument('--method', type=str) +parser.add_argument('--alpha', type=float) +parser.add_argument('--kfold', action='store_true') +parser.add_argument('--benchmark_encoders', action='store_true') +parser.add_argument('--normalize', type=bool) +parser.add_argument('--dimreduce', type=str, 
help='whether to perform dimensionality reduction before linear probing, can be "PCA" or None')
+parser.add_argument('--latent_dim', type=int, help='dimensionality reduction latent dimension')
+parser.add_argument('--encoders', nargs='+', help='All the encoders to benchmark')
+parser.add_argument('--datasets', nargs='+', help='Datasets from bench_data_root to use during benchmark')
+parser.add_argument('--config', type=str, help='Path to a benchmark config file, arguments provided in the config file will overwrite the command line args')
+
+
+@dataclass
+class BenchmarkConfig:
+    """
+    Dataclass containing the default arguments for benchmarking. Note that these defaults can be overwritten in ``benchmark`` through any of:
+    - CLI arguments
+    - Function kwargs
+    - Config file (whose path must be specified via the CLI or kwargs)
+    """
+    seed: int = 1
+    overwrite: bool = False
+    bench_data_root: Optional[str] = 'eval/bench_data'
+    embed_dataroot: Optional[str] = 'eval/ST_data_emb'
+    weights_root: Optional[str] = 'eval/fm_v1'
+    results_dir: Optional[str] = 'eval/ST_pred_results'
+    private_weights_root: Optional[str] = None
+    exp_code: Optional[str] = None
+    batch_size: int = 128
+    num_workers: int = 1
+    gene_list: str = 'var_50genes.json'
+    method: str = 'ridge'
+    alpha: Optional[float] = None
+    kfold: bool = False
+    benchmark_encoders: bool = False
+    normalize: bool = True
+    dimreduce: Optional[str] = "PCA"
+    latent_dim: int = 256
+    encoders: list = field(default_factory=lambda: ['resnet50'])
+    datasets: list = field(default_factory=lambda: ['IDC'])
+    config: Optional[str] = None
 
 
 def get_path(path):
     src = get_path_relative(__file__, '../../../../')
@@ -72,7 +104,7 @@ def get_path(path):
     return new_path
 
 
-def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None):
+def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None) -> Tuple[list, dict]:
     """ Execute predict_folds for each encoders and datasets and dump the results in a nested directory structure """
     dataset_perfs = []
@@ -80,6 +112,7 @@ def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, cus
         bench_data_root = os.path.join(get_path(args.bench_data_root), dataset)
         enc_perfs = []
         for model_name in model_names:
+            logger.info(f'HESTBench task: {dataset}, Encoder: {model_name}')
             exp_save_dir = os.path.join(save_dir, dataset, model_name)
             os.makedirs(exp_save_dir, exist_ok=True)
             enc_results = predict_folds(args, exp_save_dir, model_name, dataset, device, bench_data_root, custom_encoder)
@@ -124,6 +157,8 @@ def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, cus
     with open(os.path.join(save_dir, 'dataset_results.json'), 'w') as f:
         json.dump({'results': dataset_perfs, 'average': perf_per_enc}, f, sort_keys=True, indent=4)
+
+    return dataset_perfs, perf_per_enc
 
 
 def post_collate_fn(batch):
@@ -178,7 +213,7 @@ def get_bench_weights(weights_root, name):
     else:
         raise ValueError(f"Please specify the weights path to {name} in {local_ckpt_registry}")
 
-def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder):
+def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder, extract_tiles):
     """ Predict a single split for a single model """
 
     if not os.path.isfile(train_split):
@@ -193,10 +228,11 @@ def predict_single_split(train_split, test_split, args, save_dir, dataset_name,
     os.makedirs(embedding_dir, 
exist_ok=True)
 
     # Embed patches
-    logger.info(f"Embedding tiles using {model_name} encoder")
+    logger.info(f"Embedding tiles for {dataset_name} using {model_name} encoder")
     weights_path = get_bench_weights(args.weights_root, model_name)
     if model_name == 'custom_encoder':
         encoder = custom_encoder
+        args.overwrite = True  # cached embeddings can't be trusted for a custom encoder, so always re-extract
     else:
         encoder: InferenceEncoder = inf_encoder_factory(model_name)(weights_path)
     precision = encoder.precision
@@ -207,20 +243,21 @@ def predict_single_split(train_split, test_split, args, save_dir, dataset_name,
         tile_h5_path = os.path.join(bench_data_root, split.iloc[i]['patches_path'])
         assert os.path.isfile(tile_h5_path)
         embed_path = os.path.join(embedding_dir, f'{sample_id}.h5')
-        if not os.path.isfile(embed_path) or args.overwrite:
-
-            _ = encoder.eval()
-            encoder.to(device)
-
-            tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
-            tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
-                                                          batch_size=1,
-                                                          shuffle=False,
-                                                          num_workers=args.num_workers)
-
-            _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
-        else:
-            logger.info(f"Skipping {sample_id} as it already exists")
+        if extract_tiles:
+            if not os.path.isfile(embed_path) or args.overwrite:
+
+                _ = encoder.eval()
+                encoder.to(device)
+
+                tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
+                tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
+                                                              batch_size=1,
+                                                              shuffle=False,
+                                                              num_workers=args.num_workers)
+
+                _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
+            else:
+                logger.info(f"Skipping {sample_id} as it already exists")
 
 
     with open(os.path.join(save_dir, 'config.json'), 'w') as f:
@@ -300,28 +337,6 @@ def merge_fold_results(arr):
     mean_per_split = [d['pearson_mean'] for d in arr]
     return {"pearson_corrs": aggr_results, "pearson_mean": np.mean(mean_per_split), "pearson_std": np.std(mean_per_split), "mean_per_split": mean_per_split}
-
-
-def benchmark_encoder(encoder: torch.nn.Module, enc_transf, precision: torch.dtype, config_path: str) -> dict:
-    """ Launch HEST-Benchmark
-
-    Args:
-        encoder (torch.nn.Module): model to benchmark
-        enc_transf: evaluation transforms
-        precision (torch.dtype): inference precision
-        config_path (str): path to a hest-bench config file
-
-    Returns:
-        dict: results dictionary
-    """
-
-    args = parser.parse_args()
-
-
-    args.config = config_path
-
-
-    benchmark(args, encoder=encoder, enc_transf=enc_transf, precision=precision)
 
 
 def predict_folds(args, exp_save_dir, model_name, dataset_name, device, bench_data_root, custom_encoder):
@@ -339,7 +354,8 @@ def predict_folds(args, exp_save_dir, model_name, dataset_name, device, bench_da
         test_split = os.path.join(split_dir, f'test_{i}.csv')
         kfold_save_dir = os.path.join(exp_save_dir, f'split{i}')
         os.makedirs(kfold_save_dir, exist_ok=True)
-        linprobe_results = predict_single_split(train_split, test_split, args, kfold_save_dir, dataset_name, model_name, device=device, bench_data_root=bench_data_root, custom_encoder=custom_encoder)
+        extract_tiles = (i == 0)  # tile embeddings are shared across folds, so extract them only on the first fold
+        linprobe_results = predict_single_split(train_split, test_split, args, kfold_save_dir, dataset_name, model_name, device=device, bench_data_root=bench_data_root, custom_encoder=custom_encoder, extract_tiles=extract_tiles)
         libprobe_results_arr.append(linprobe_results)
 
 
@@ -360,16 +376,33 @@ def set_seed(seed):
     random.seed(seed)
 
 
-def benchmark(args, encoder, enc_transf, precision):
+def benchmark(encoder, enc_transf, precision, cli_args=None, **kwargs) -> Tuple[list, dict]:
+
+    # get default args - overwritten if using CLI args, kwargs, or a config file
+    args = Namespace(**asdict(BenchmarkConfig()))
+
+    # Prio 1 - overwrite with CLI args
+    if cli_args is not None:
+        for k, v in vars(cli_args).items():
+            if v is not None:
+                print(f"Updating {k} with {v}")
+                setattr(args, k, v)
+
+    # Prio 2 - overwrite with kwargs if provided
+    for k, v in kwargs.items():
+        if v is not None:
+            print(f"Updating {k} with {v}")
+            setattr(args, k, v)
 
-    if args.config is not None:
+    # Prio 3 - overwrite with the config file if provided (applied last, so it takes precedence over CLI args and kwargs)
+    if args.config is not None:
         with open(args.config) as stream:
             config = yaml.safe_load(stream)
-            for key in config:
-                if key in args:
-                    setattr(args, key, config[key])
-    
+        for key in config:
+            if key in args:
+                setattr(args, key, config[key])
+
     set_seed(args.seed)
 
     logger.info(f'Saving models to {args.weights_root}...')
@@ -406,18 +439,16 @@ def benchmark(args, encoder, enc_transf, precision):
     else:
         custom_encoder = None
 
-    encoders += config['encoders']
+    encoders += args.encoders
 
-    benchmark_grid(args, device, encoders, datasets, save_dir=save_dir, custom_encoder=custom_encoder)
-    
+    dataset_perfs, perf_per_enc = benchmark_grid(args, device, encoders, datasets, save_dir=save_dir, custom_encoder=custom_encoder)
+    return dataset_perfs, perf_per_enc
 
 
-if __name__ == '__main__':
-    args = parser.parse_args()
-    if args.config is None:
-        parser.error("Please provide --config")
+if __name__ == '__main__':
+    cli_args = parser.parse_args()
+
+    benchmark(None, None, None, cli_args)
 
-    benchmark(args, None, None, None)
-    
\ No newline at end of file
diff --git a/src/hest/bench/cpath_model_zoo/inference_models.py b/src/hest/bench/cpath_model_zoo/inference_models.py
index 1db6862..08d5c4c 100644
--- a/src/hest/bench/cpath_model_zoo/inference_models.py
+++ b/src/hest/bench/cpath_model_zoo/inference_models.py
@@ -75,7 +75,7 @@ class CustomInferenceEncoder(InferenceEncoder):
     def __init__(self, weights_path, name, model, transforms, precision):
         super().__init__(weights_path)
         self.model = model
-        self.transforms = transforms
+        self.eval_transforms = transforms
         self.precision = precision
 
     def _build(self, weights_path):
diff --git a/tutorials/4-Running-HEST-Benchmark.ipynb b/tutorials/4-Running-HEST-Benchmark.ipynb
index 7462ca8..409cc09 100644
--- a/tutorials/4-Running-HEST-Benchmark.ipynb
+++ b/tutorials/4-Running-HEST-Benchmark.ipynb
@@ -4,12 +4,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Step-by-step instructions to run HEST-Benchmark\n",
+    "\n"
   ]
  },
  {
@@ -30,7 +30,7 @@
    "| Task 6 | READ | 4 | Visium | ZEN36, ZEN40, ZEN48, ZEN49 |\n",
    "| Task 7 | ccRCC | 24 | Visium | INT1~INT24 |\n",
    "| Task 8 | HCC | 2 | Visium | NCBI642, NCBI643 |\n",
-    "| Task 9 | LUAD | 2 | Xenium | TENX118, TENX141 |\n",
+    "| Task 9 | LUNG | 2 | Xenium | TENX118, TENX141 |\n",
    "| Task 10 | IDC-LymphNode | 4 | Visium | NCBI681, NCBI682, NCBI683, NCBI684 |\n",
    "\n"
   ]
  },
@@ -102,7 +102,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from hest.bench import benchmark_encoder\n",
+    "from hest.bench import benchmark\n",
     "import torch\n",
     "\n",
     "PATH_TO_CONFIG = .. # path to `bench_config.yaml`\n",
     "model_transforms = .. # transforms to apply during inference (torchvision.transforms.Compose)\n",
     "precision = torch.float32\n",
     "\n",
-    "benchmark_encoder( \n",
+    "benchmark( \n",
     "    model, \n",
     "    model_transforms,\n",
     "    precision,\n",
-    "    PATH_TO_CONFIG\n",
+    "    config=PATH_TO_CONFIG, \n",
     ")"
   ]
  },
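
Usage sketch (editor's illustration, not part of the diff): how the reworked `benchmark` entry point introduced by this patch can be driven from Python. The torchvision ResNet-50 trunk, the ImageNet-style transforms, the config path, and the `datasets` override below are all assumed placeholders; substitute your own encoder and `bench_config.yaml` location.

# Illustration only -- model, transforms, and paths below are assumptions.
import torch
import torchvision
import torchvision.transforms as T

from hest.bench import benchmark

# Stand-in custom encoder: a ResNet-50 trunk with the classifier removed,
# so each tile is embedded as a 2048-d feature vector.
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.fc = torch.nn.Identity()

# ImageNet-style eval transforms; swap in whatever your encoder expects.
model_transforms = T.Compose([
    T.ToTensor(),
    T.Resize((224, 224), antialias=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# kwargs (Prio 2) override the BenchmarkConfig defaults; the YAML passed via
# `config` (Prio 3) is applied last and supersedes kwargs and CLI args.
dataset_perfs, perf_per_enc = benchmark(
    model,
    model_transforms,
    torch.float32,
    config="bench_config/bench_config.yaml",  # assumed path, as in this repo
    datasets=["IDC"],
)

Calling benchmark(None, None, None) instead falls back entirely on the BenchmarkConfig defaults (resnet50 on IDC), which is effectively what the new __main__ block does when no CLI flags are passed.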