diff --git a/docs/source/api.md b/docs/source/api.md
index 08a69ba..2816f24 100644
--- a/docs/source/api.md
+++ b/docs/source/api.md
@@ -24,7 +24,7 @@
 .. autosummary::
     :toctree: generated

-    benchmark_encoder
+    benchmark
 ```

 ## HESTData class
diff --git a/src/hest/bench/__init__.py b/src/hest/bench/__init__.py
index a374b1c..8c1d799 100644
--- a/src/hest/bench/__init__.py
+++ b/src/hest/bench/__init__.py
@@ -1,2 +1,2 @@
 from . import st_dataset
-from .benchmark import benchmark_encoder
\ No newline at end of file
+from .benchmark import benchmark
\ No newline at end of file
diff --git a/src/hest/bench/benchmark.py b/src/hest/bench/benchmark.py
index 238e52e..43b14c9 100644
--- a/src/hest/bench/benchmark.py
+++ b/src/hest/bench/benchmark.py
@@ -4,7 +4,7 @@
 import json
 import os
 from operator import itemgetter
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from dataclasses import dataclass, asdict, field
 from argparse import Namespace
@@ -104,7 +104,7 @@ def get_path(path):
     return new_path


-def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None):
+def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None) -> Tuple[list, dict]:
     """ Execute predict_folds for each encoders and datasets and dump the results in a nested directory structure """
     dataset_perfs = []

@@ -112,6 +112,7 @@ def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None):
         bench_data_root = os.path.join(get_path(args.bench_data_root), dataset)
         enc_perfs = []
         for model_name in model_names:
+            logger.info(f'HESTBench task: {dataset}, Encoder: {model_name}')
             exp_save_dir = os.path.join(save_dir, dataset, model_name)
             os.makedirs(exp_save_dir, exist_ok=True)
             enc_results = predict_folds(args, exp_save_dir, model_name, dataset, device, bench_data_root, custom_encoder)
@@ -156,6 +157,8 @@ def benchmark_grid(args, device, model_names, datasets: List[str], save_dir, custom_encoder=None):
     with open(os.path.join(save_dir, 'dataset_results.json'), 'w') as f:
         json.dump({'results': dataset_perfs, 'average': perf_per_enc}, f, sort_keys=True, indent=4)
+
+    return dataset_perfs, perf_per_enc


 def post_collate_fn(batch):
@@ -210,7 +213,7 @@ def get_bench_weights(weights_root, name):
     else:
         raise ValueError(f"Please specify the weights path to {name} in {local_ckpt_registry}")


-def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder):
+def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder, extract_tiles):
     """ Predict a single split for a single model """
     if not os.path.isfile(train_split):
@@ -225,10 +228,11 @@ def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder):
     os.makedirs(embedding_dir, exist_ok=True)

     # Embed patches
-    logger.info(f"Embedding tiles using {model_name} encoder")
+    logger.info(f"Embedding tiles for {dataset_name} using {model_name} encoder")
     weights_path = get_bench_weights(args.weights_root, model_name)
     if model_name == 'custom_encoder':
         encoder = custom_encoder
+        args.overwrite = True # always overwrite custom encoders
     else:
         encoder: InferenceEncoder = inf_encoder_factory(model_name)(weights_path)
     precision = encoder.precision
@@ -239,20 +243,21 @@ def predict_single_split(train_split, test_split, args, save_dir, dataset_name, model_name, device, bench_data_root, custom_encoder):
         tile_h5_path = os.path.join(bench_data_root, split.iloc[i]['patches_path'])
         assert os.path.isfile(tile_h5_path)
         embed_path = os.path.join(embedding_dir, f'{sample_id}.h5')
-        if not os.path.isfile(embed_path) or args.overwrite:
-
-            _ = encoder.eval()
-            encoder.to(device)
-
-            tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
-            tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
-                                                          batch_size=1,
-                                                          shuffle=False,
-                                                          num_workers=args.num_workers)
-
-            _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
-        else:
-            logger.info(f"Skipping {sample_id} as it already exists")
+        if extract_tiles:
+            if not os.path.isfile(embed_path) or args.overwrite:
+
+                _ = encoder.eval()
+                encoder.to(device)
+
+                tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
+                tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
+                                                              batch_size=1,
+                                                              shuffle=False,
+                                                              num_workers=args.num_workers)
+
+                _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
+            else:
+                logger.info(f"Skipping {sample_id} as it already exists")


     with open(os.path.join(save_dir, 'config.json'), 'w') as f:
@@ -349,7 +354,8 @@ def predict_folds(args, exp_save_dir, model_name, dataset_name, device, bench_data_root, custom_encoder):
         test_split = os.path.join(split_dir, f'test_{i}.csv')
         kfold_save_dir = os.path.join(exp_save_dir, f'split{i}')
         os.makedirs(kfold_save_dir, exist_ok=True)
-        linprobe_results = predict_single_split(train_split, test_split, args, kfold_save_dir, dataset_name, model_name, device=device, bench_data_root=bench_data_root, custom_encoder=custom_encoder)
+        extract_tiles = True if i == 0 else False
+        linprobe_results = predict_single_split(train_split, test_split, args, kfold_save_dir, dataset_name, model_name, device=device, bench_data_root=bench_data_root, custom_encoder=custom_encoder, extract_tiles=extract_tiles)

         libprobe_results_arr.append(linprobe_results)
@@ -370,7 +376,7 @@ def set_seed(seed):
     random.seed(seed)


-def benchmark(encoder, enc_transf, precision, cli_args=None, **kwargs):
+def benchmark(encoder, enc_transf, precision, cli_args=None, **kwargs) -> Tuple[list, dict]:
     # get default args - overwritten if using CLI, kwargs, or config file
     args = Namespace(**asdict(BenchmarkConfig()))
@@ -435,8 +441,9 @@ def benchmark(encoder, enc_transf, precision, cli_args=None, **kwargs):

     encoders += args.encoders

-    benchmark_grid(args, device, encoders, datasets, save_dir=save_dir, custom_encoder=custom_encoder)
-
+    dataset_perfs, perf_per_enc = benchmark_grid(args, device, encoders, datasets, save_dir=save_dir, custom_encoder=custom_encoder)
+
+    return dataset_perfs, perf_per_enc
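Two behavioral notes on this change. First, `benchmark_grid` (and therefore the public `benchmark` entry point) now returns `(dataset_perfs, perf_per_enc)` in addition to writing `dataset_results.json`, so callers can consume results directly. Second, the new `extract_tiles` flag restricts tile embedding to fold 0 (`extract_tiles = True if i == 0 else False`); later folds reuse the `{sample_id}.h5` embeddings already on disk, which is sound as long as the embedding directory is shared across fold subdirectories, since tile embeddings do not depend on the train/test split.

Below is a minimal sketch of calling the renamed entry point and consuming the new return value. Only the import path, the positional arguments (`encoder`, `enc_transf`, `precision`), and the return tuple come from this diff; the timm backbone and the transform pipeline are stand-ins, and any non-default configuration would be passed via `cli_args` or `**kwargs` as the `BenchmarkConfig` defaults suggest.

```python
import timm
import torch
import torchvision.transforms as T

from hest.bench import benchmark  # renamed from benchmark_encoder in this diff

# Stand-in encoder: any torch.nn.Module mapping a batch of tiles to
# embeddings works; num_classes=0 makes timm return pooled features.
encoder = timm.create_model('resnet50', pretrained=True, num_classes=0)

# Stand-in eval transform applied to each tile before embedding.
enc_transf = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# benchmark() now returns the per-dataset results and the per-encoder
# averages instead of only writing dataset_results.json to disk.
dataset_perfs, perf_per_enc = benchmark(encoder, enc_transf, torch.float32)

print(perf_per_enc)
```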