To achieve robust predictions on unseen data, Auto3dSeg provides a model ensemble module to summarize predictions from various trained models. The module first ranks checkpoints of different algorithms based on validation accuracy in each fold of N-fold cross-validation, picks the top-M algorithms from each fold, and creates ensemble predictions using the resulting M x N checkpoints. The default ensemble algorithm averages the probability maps (from softmax/sigmoid activations) of the candidate predictions and generates the final outputs.
The following script shows how to ensemble predictions from different algorithms using the AlgoEnsembleBuilder class.
import os
from monai.apps.auto3dseg import (
    AlgoEnsembleBestN,
    AlgoEnsembleBuilder,
    import_bundle_algo_history,
)
from monai.utils.enums import AlgoKeys

# Assuming you have already trained the models
work_dir = <your_work_dir>  # the algorithm working directory generated by AlgoGen/BundleGen
input_cfg = <your_task_input_file>  # path to the task input YAML file created by the users
history = import_bundle_algo_history(work_dir, only_trained=True)

## model ensemble
n_best = 1
builder = AlgoEnsembleBuilder(history, input_cfg)
builder.set_ensemble_method(AlgoEnsembleBestN(n_best=n_best))
ensemble = builder.get_ensemble()
pred = ensemble()
print("ensemble picked the following best {0:d}:".format(n_best))
for algo in ensemble.get_algo_ensemble():
    print(algo[AlgoKeys.ID])
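The returned "pred" holds one ensembled prediction per testing image defined in the task input file. As a follow-up, one possible way to write the results to disk with "monai.transforms.SaveImage" is sketched below; it assumes each prediction carries its source image metadata (a MetaTensor), and the output directory name is an arbitrary example.

from monai.transforms import SaveImage

# Illustrative sketch: persist each ensembled prediction to disk.
# Assumes each item in `pred` is a MetaTensor with its source image metadata.
saver = SaveImage(output_dir="./ensemble_output", output_postfix="ensemble", separate_folder=False)
for p in pred:
    saver(p)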
Auto3DSeg also provides an API for users to bring their own customized ensemble algorithms, as shown in this notebook.
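For instance, a customized ensemble can be defined by subclassing the "AlgoEnsemble" base class and overriding its "collect_algos()" method, which decides which trained algorithms participate in the ensemble. The sketch below is illustrative only (the keep-all selection is just a placeholder strategy); it assumes "AlgoEnsemble" is importable from monai.apps.auto3dseg in your MONAI version, and reuses "history" and "input_cfg" from the snippet above.

from monai.apps.auto3dseg import AlgoEnsemble, AlgoEnsembleBuilder

# Illustrative sketch (not part of the library): a customized ensemble that keeps
# every trained algorithm regardless of its validation score.
class AlgoEnsembleKeepAll(AlgoEnsemble):
    def collect_algos(self):
        # self.algos is populated by AlgoEnsembleBuilder via set_algos()
        self.algo_ensemble = list(self.algos)

builder = AlgoEnsembleBuilder(history, input_cfg)
builder.set_ensemble_method(AlgoEnsembleKeepAll())
ensemble = builder.get_ensemble()
pred = ensemble()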
The essential component for the model ensemble is the "infer()" function in the "InferClass" class of the script "scripts/infer.py". After the "InferClass" is initialized, "infer()" takes image file names as input and outputs multi-channel probability maps. Each algorithm's "infer.py" is located inside its bundle template. In general, the ensemble module works with any algorithm as long as the "infer()" function is provided with the proper setup.
class InferClass:
    ...

    @torch.no_grad()
    def infer(self, image_file):
        self.model.eval()

        # pre-process the input image and collate it into a batch of size 1
        batch_data = self.infer_transforms(image_file)
        batch_data = list_data_collate([batch_data])
        infer_image = batch_data["image"].to(self.device)

        # sliding-window inference with automatic mixed precision
        with torch.cuda.amp.autocast():
            batch_data["pred"] = sliding_window_inference(
                infer_image,
                self.patch_size_valid,
                self.num_sw_batch_size,
                self.model,
                mode="gaussian",
                overlap=self.overlap_ratio,
            )

        # apply post-transforms and return the prediction for the single input
        batch_data = [self.post_transforms(i) for i in decollate_batch(batch_data)]
        return batch_data[0]["pred"]
    ...