Feature/hestbench flexibility (#77)
* MODELS: CustomInferenceEncoder fix in factory

* EVAL: sensible eval defaults for easier in-pipeline workflow

* GIT: ignore .pyc

* EVAL: prevent repeat extraction when overwrite=True
konst-int-i authored Dec 6, 2024
1 parent c8bef8c commit 755efbb
Showing 7 changed files with 119 additions and 87 deletions.
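In practice, the new defaults mean HEST-Benchmark can now be driven from inside a pipeline with a single call, and its results consumed directly. A minimal sketch of that workflow, assuming the eval/ paths from the updated bench_config defaults; my_model, my_transforms, and the 'IDC' task are placeholders, not part of this commit:

import torch
from hest.bench import benchmark

my_model = ...        # torch.nn.Module to evaluate (placeholder)
my_transforms = ...   # inference transforms, e.g. a torchvision Compose (placeholder)

# Defaults come from BenchmarkConfig; kwargs override them, and a YAML
# config passed as config=... overrides both.
dataset_perfs, perf_per_enc = benchmark(
    my_model,
    my_transforms,
    torch.float32,
    datasets=['IDC'],
)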
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ tests/assets
config
tests/output_tests
HEST/
+*.pyc

results
atlas
8 changes: 4 additions & 4 deletions bench_config/bench_config.yaml
@@ -1,14 +1,14 @@
# directory containing the data for each task
-bench_data_root: 'bench_data'
+bench_data_root: 'eval/bench_data'

# directory where benchmark results will be dumped
-results_dir: 'ST_pred_results'
+results_dir: 'eval/ST_pred_results'

# directory where the vision embeddings will be dumped
-embed_dataroot: 'ST_data_emb'
+embed_dataroot: 'eval/ST_data_emb'

# directory to the model weights root
-weights_root: 'fm_v1'
+weights_root: 'eval/fm_v1'

# inference parameters
batch_size: 128
2 changes: 1 addition & 1 deletion docs/source/api.md
@@ -24,7 +24,7 @@
.. autosummary::
:toctree: generated
-   benchmark_encoder
+   benchmark
```

## HESTData class
2 changes: 1 addition & 1 deletion src/hest/bench/__init__.py
@@ -1,2 +1,2 @@
from . import st_dataset
-from .benchmark import benchmark_encoder
+from .benchmark import benchmark
179 changes: 105 additions & 74 deletions src/hest/bench/benchmark.py
@@ -4,7 +4,10 @@
import json
import os
from operator import itemgetter
-from typing import Callable, List
+from typing import List, Optional, Tuple
+from dataclasses import dataclass, asdict, field
+from argparse import Namespace
+

import numpy as np
import pandas as pd
@@ -31,37 +34,66 @@
save_pkl)
from hest.bench.utils.utils import merge_dict, get_current_time

-# Generic training settings
+# Generic training settings - note that defaults are set in BenchmarkConfig
parser = argparse.ArgumentParser(description='Configurations for linear probing')
### optimizer settings ###
-parser.add_argument('--seed', type=int, default=1,
-                    help='random seed for reproducible experiment (default: 1)')
-parser.add_argument('--overwrite', action='store_true', default=False,
+parser.add_argument('--seed', type=int,
+                    help='random seed for reproducible experiment')
+parser.add_argument('--overwrite', action='store_true',
help='overwrite existing results')
parser.add_argument('--bench_data_root', type=str, help='root directory containing all the datasets')
parser.add_argument('--embed_dataroot', type=str)
parser.add_argument('--weights_root', type=str)
-parser.add_argument('--private_weights_root', type=str, default=None)
+parser.add_argument('--private_weights_root', type=str)
parser.add_argument('--results_dir', type=str)
-parser.add_argument('--exp_code', type=str, default=None)
+parser.add_argument('--exp_code', type=str)

### specify encoder settings ###
-parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
-parser.add_argument('--num_workers', type=int, default=1, help='Number of workers for dataloader')
+parser.add_argument('--batch_size', type=int, help='Batch size')
+parser.add_argument('--num_workers', type=int, help='Number of workers for dataloader')

### specify dataset settings ###
-parser.add_argument('--gene_list', type=str, default='var_50genes.json')
-parser.add_argument('--method', type=str, default='ridge')
-parser.add_argument('--alpha', type=float, default=None)
-parser.add_argument('--kfold', action='store_true', default=False)
-parser.add_argument('--benchmark_encoders', action='store_true', default=False)
-parser.add_argument('--normalize', type=bool, default=True)
-parser.add_argument('--dimreduce', type=str, default=None, help='whenever to perform dimensionality reduction before linear probing, can be "PCA" or None')
-parser.add_argument('--latent_dim', type=int, default=256, help='dimensionality reduction latent dimension')
-parser.add_argument('--encoders', nargs='+', help='All the encoders to benchmark', default=[])
-parser.add_argument('--datasets', nargs='+', help='Datasets from bench_data_root to use during benchmark', default=['*'])
-parser.add_argument('--config', type=str, help='Path to a benchmark config file, arguments provided in the config file will overwrite the command line args', default=None)
+parser.add_argument('--gene_list', type=str)
+parser.add_argument('--method', type=str)
+parser.add_argument('--alpha', type=float)
+parser.add_argument('--kfold', action='store_true')
+parser.add_argument('--benchmark_encoders', action='store_true')
+parser.add_argument('--normalize', type=bool)
+parser.add_argument('--dimreduce', type=str, help='whether to perform dimensionality reduction before linear probing; can be "PCA" or None')
+parser.add_argument('--latent_dim', type=int, help='dimensionality reduction latent dimension')
+parser.add_argument('--encoders', nargs='+', help='All the encoders to benchmark')
+parser.add_argument('--datasets', nargs='+', help='Datasets from bench_data_root to use during benchmark')
+parser.add_argument('--config', type=str, help='Path to a benchmark config file; arguments provided in the config file will overwrite the command line args')

+@dataclass
+class BenchmarkConfig:
+    """
+    Dataclass containing default arguments for benchmarking. Note that arguments are overwritten in ``benchmark`` through any of:
+    - CLI arguments
+    - Function kwargs
+    - Config file (whose path needs to be specified via the CLI or kwargs)
+    """
+    seed: int = 1
+    overwrite: bool = False
+    bench_data_root: Optional[str] = 'eval/bench_data'
+    embed_dataroot: Optional[str] = 'eval/ST_data_emb'
+    weights_root: Optional[str] = 'eval/fm_v1'
+    results_dir: Optional[str] = 'eval/ST_pred_results'
+    private_weights_root: Optional[str] = None
+    exp_code: Optional[str] = None
+    batch_size: int = 128
+    num_workers: int = 1
+    gene_list: str = 'var_50genes.json'
+    method: str = 'ridge'
+    alpha: Optional[float] = None
+    kfold: bool = False
+    benchmark_encoders: bool = False
+    normalize: bool = True
+    dimreduce: Optional[str] = "PCA"
+    latent_dim: int = 256
+    encoders: list = field(default_factory=lambda: ['resnet50'])
+    datasets: list = field(default_factory=lambda: ['IDC'])
+    config: Optional[str] = None

def get_path(path):
src = get_path_relative(__file__, '../../../../')
@@ -72,14 +104,15 @@ def get_path(path):
return new_path


-def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None):
+def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None) -> Tuple[list, dict]:
""" Execute predict_folds for each encoders and datasets and dump the results in a nested directory structure """

dataset_perfs = []
for dataset in datasets:
bench_data_root = os.path.join(get_path(args.bench_data_root), dataset)
enc_perfs = []
for model_name in model_names:
+            logger.info(f'HESTBench task: {dataset}, Encoder: {model_name}')
exp_save_dir = os.path.join(save_dir, dataset, model_name)
os.makedirs(exp_save_dir, exist_ok=True)
enc_results = predict_folds(args, exp_save_dir, model_name, dataset, device, bench_data_root, custom_encoder)
@@ -124,6 +157,8 @@ def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None):

with open(os.path.join(save_dir, 'dataset_results.json'), 'w') as f:
json.dump({'results': dataset_perfs, 'average': perf_per_enc}, f, sort_keys=True, indent=4)

+    return dataset_perfs, perf_per_enc


def post_collate_fn(batch):
@@ -178,7 +213,7 @@ def get_bench_weights(weights_root, name):
else:
raise ValueError(f"Please specify the weights path to {name} in {local_ckpt_registry}")

-def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder):
+def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder, extract_tiles):
""" Predict a single split for a single model """

if not os.path.isfile(train_split):
@@ -193,10 +228,11 @@ def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder):
os.makedirs(embedding_dir, exist_ok=True)

# Embed patches
logger.info(f"Embedding tiles using {model_name} encoder")
logger.info(f"Embedding tiles for {dataset_name} using {model_name} encoder")
weights_path = get_bench_weights(args.weights_root, model_name)
if model_name == 'custom_encoder':
encoder = custom_encoder
+        args.overwrite = True # always overwrite custom encoders
else:
encoder: InferenceEncoder = inf_encoder_factory(model_name)(weights_path)
precision = encoder.precision
@@ -207,20 +243,21 @@ def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder):
tile_h5_path = os.path.join(bench_data_root, split.iloc[i]['patches_path'])
assert os.path.isfile(tile_h5_path)
embed_path = os.path.join(embedding_dir, f'{sample_id}.h5')
-        if not os.path.isfile(embed_path) or args.overwrite:
-
-            _ = encoder.eval()
-            encoder.to(device)
-
-            tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
-            tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
-                                                          batch_size=1,
-                                                          shuffle=False,
-                                                          num_workers=args.num_workers)
-
-            _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
-        else:
-            logger.info(f"Skipping {sample_id} as it already exists")
+        if extract_tiles:
+            if not os.path.isfile(embed_path) or args.overwrite:
+
+                _ = encoder.eval()
+                encoder.to(device)
+
+                tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
+                tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
+                                                              batch_size=1,
+                                                              shuffle=False,
+                                                              num_workers=args.num_workers)
+
+                _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
+            else:
+                logger.info(f"Skipping {sample_id} as it already exists")


with open(os.path.join(save_dir, 'config.json'), 'w') as f:
@@ -300,28 +337,6 @@ def merge_fold_results(arr):
mean_per_split = [d['pearson_mean'] for d in arr]

return {"pearson_corrs": aggr_results, "pearson_mean": np.mean(mean_per_split), "pearson_std": np.std(mean_per_split), "mean_per_split": mean_per_split}


-def benchmark_encoder(encoder: torch.nn.Module, enc_transf, precision: torch.dtype, config_path: str) -> dict:
-    """ Launch HEST-Benchmark
-    Args:
-        encoder (torch.nn.Module): model to benchmark
-        enc_transf: evaluation transforms
-        precision (torch.dtype): inference precision
-        config_path (str): path to a hest-bench config file
-    Returns:
-        dict: results dictionary
-    """
-
-    args = parser.parse_args()
-
-
-    args.config = config_path
-
-
-    benchmark(args, encoder=encoder, enc_transf=enc_transf, precision=precision)


def predict_folds(args, exp_save_dir, model_name, dataset_name, device, bench_data_root, custom_encoder):
@@ -339,7 +354,8 @@ def predict_folds(args, exp_save_dir, model_name, dataset_name, device, bench_data_root, custom_encoder):
test_split = os.path.join(split_dir, f'test_{i}.csv')
kfold_save_dir = os.path.join(exp_save_dir, f'split{i}')
os.makedirs(kfold_save_dir, exist_ok=True)
-        linprobe_results = predict_single_split(train_split, test_split, args, kfold_save_dir, dataset_name, model_name, device=device, bench_data_root=bench_data_root, custom_encoder=custom_encoder)
+        extract_tiles = True if i == 0 else False
+        linprobe_results = predict_single_split(train_split, test_split, args, kfold_save_dir, dataset_name, model_name, device=device, bench_data_root=bench_data_root, custom_encoder=custom_encoder, extract_tiles=extract_tiles)
libprobe_results_arr.append(linprobe_results)


@@ -360,16 +376,33 @@ def set_seed(seed):
random.seed(seed)


-def benchmark(args, encoder, enc_transf, precision):
+def benchmark(encoder, enc_transf, precision, cli_args=None, **kwargs) -> Tuple[list, dict]:
+
+    # get default args - overwritten if using CLI, kwargs, or config file
+    args = Namespace(**asdict(BenchmarkConfig()))
+
+    # Prio 1 - overwrite with CLI args
+    if cli_args is not None:
+        for k, v in vars(cli_args).items():
+            if v is not None:
+                print(f"Updating {k} with {v}")
+                setattr(args, k, v)
+
+    # Prio 2 - overwrite with kwargs if provided
+    for k, v in kwargs.items():
+        if v is not None:
+            print(f"Updating {k} with {v}")
+            setattr(args, k, v)
 
-    if args.config is not None:
+    # Prio 3 - overwrite defaults with config if provided
+    if args.config is not None:
         with open(args.config) as stream:
             config = yaml.safe_load(stream)
 
         for key in config:
-            if key in args:
+            if key in args:
                 setattr(args, key, config[key])
 
set_seed(args.seed)

logger.info(f'Saving models to {args.weights_root}...')
@@ -406,18 +439,16 @@ def benchmark(args, encoder, enc_transf, precision):
else:
custom_encoder = None

-    encoders += config['encoders']
+    encoders += args.encoders

-    benchmark_grid(args, device, encoders, datasets, save_dir=save_dir, custom_encoder=custom_encoder)
+    dataset_perfs, perf_per_enc = benchmark_grid(args, device, encoders, datasets, save_dir=save_dir, custom_encoder=custom_encoder)
+
+    return dataset_perfs, perf_per_enc


-if __name__ == '__main__':
-    args = parser.parse_args()
-
-    if args.config is None:
-        parser.error("Please provide --config")
-
-    benchmark(args, None, None, None)
+if __name__ == '__main__':
+    cli_args = parser.parse_args()
+
+    benchmark(None, None, None, cli_args)
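For clarity, the effective override order the new benchmark implements is: BenchmarkConfig defaults, then CLI args, then kwargs, then the YAML config, with later stages winning because they are applied last (so a config-file key takes final precedence). A small sketch of the mechanism, using only pieces defined in this commit; the override values are hypothetical:

from argparse import Namespace
from dataclasses import asdict

from hest.bench.benchmark import BenchmarkConfig

# Stage 0: dataclass defaults
args = Namespace(**asdict(BenchmarkConfig()))
print(args.batch_size)   # 128, the BenchmarkConfig default

# Stages 1-2 (CLI args, then kwargs): only non-None values overwrite
overrides = {'batch_size': 64, 'alpha': None}   # hypothetical values
for k, v in overrides.items():
    if v is not None:
        setattr(args, k, v)
print(args.batch_size)   # 64; alpha keeps its default (None)

# Stage 3: keys from the YAML config (if args.config is set) overwrite last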

2 changes: 1 addition & 1 deletion src/hest/bench/cpath_model_zoo/inference_models.py
@@ -75,7 +75,7 @@ class CustomInferenceEncoder(InferenceEncoder):
def __init__(self, weights_path, name, model, transforms, precision):
super().__init__(weights_path)
self.model = model
-        self.transforms = transforms
+        self.eval_transforms = transforms
self.precision = precision

def _build(self, weights_path):
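This one-line rename is the "CustomInferenceEncoder fix in factory" from the commit message: the benchmark reads transforms from encoder.eval_transforms (see H5HESTDataset(..., img_transform=encoder.eval_transforms) in benchmark.py above), so transforms stored under self.transforms were silently ignored for custom encoders. A rough sketch of the attribute contract, with my_model and my_transforms as placeholders:

import torch

encoder = CustomInferenceEncoder(
    weights_path=None,         # no checkpoint needed here (assumption for this sketch)
    name='custom_encoder',
    model=my_model,            # placeholder torch.nn.Module
    transforms=my_transforms,  # placeholder eval-time transforms
    precision=torch.float32,
)

# What the benchmark actually consumes after the fix:
assert encoder.eval_transforms is my_transforms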
12 changes: 6 additions & 6 deletions tutorials/4-Running-HEST-Benchmark.ipynb
@@ -4,12 +4,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step-by-step instructions to run HEST-Benchmark\n",
"<!-- ## Step-by-step instructions to run HEST-Benchmark\n",
"\n",
"This tutorial will guide you to:\n",
"\n",
"- **Reproduce** HEST-Benchmark results provided in the paper (Random Forest regression and Ridge regression models)\n",
"- Benchmark your **own** model\n"
"- Benchmark your **own** model -->\n"
]
},
{
@@ -30,7 +30,7 @@
"| Task 6 | READ | 4 | Visium | ZEN36, ZEN40, ZEN48, ZEN49 |\n",
"| Task 7 | ccRCC | 24 | Visium | INT1~INT24 |\n",
"| Task 8 | HCC | 2 | Visium | NCBI642, NCBI643 |\n",
"| Task 9 | LUAD | 2 | Xenium | TENX118, TENX141 |\n",
"| Task 9 | LUNG | 2 | Xenium | TENX118, TENX141 |\n",
"| Task 10 | IDC-LymphNode | 4 | Visium | NCBI681, NCBI682, NCBI683, NCBI684 |\n",
"\n"
]
@@ -102,19 +102,19 @@
"metadata": {},
"outputs": [],
"source": [
"from hest.bench import benchmark_encoder\n",
"from hest.bench import benchmark\n",
"import torch\n",
"\n",
"PATH_TO_CONFIG = .. # path to `bench_config.yaml`\n",
"model = .. # PyTorch model (torch.nn.Module)\n",
"model_transforms = .. # transforms to apply during inference (torchvision.transforms.Compose)\n",
"precision = torch.float32\n",
"\n",
"benchmark_encoder( \n",
"benchmark( \n",
" model, \n",
" model_transforms,\n",
" precision,\n",
" PATH_TO_CONFIG\n",
" config=PATH_TO_CONFIG, \n",
")"
]
}
