# Model Ensemble

To achieve robust predictions on unseen data, Auto3DSeg provides a model ensemble module that summarizes the predictions of multiple trained models. The module first ranks the checkpoints of the different algorithms by validation accuracy in each fold of the N-fold cross-validation, picks the top-M algorithms from each fold, and creates ensemble predictions from the resulting M × N checkpoints. The default ensemble algorithm averages the probability maps (from softmax/sigmoid activations) of the candidate predictions and generates the final output.
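To make the averaging step concrete, the sketch below (plain NumPy, independent of the MONAI API; the shapes and the number of checkpoints are illustrative assumptions) averages softmax probability maps from several hypothetical checkpoints and takes the argmax as the final label map:

```python
import numpy as np

# Hypothetical probability maps for one image from M x N checkpoints:
# each has shape (num_classes, H, W, D) and comes from a softmax activation.
prob_maps = [np.random.rand(3, 64, 64, 64) for _ in range(8)]
prob_maps = [p / p.sum(axis=0, keepdims=True) for p in prob_maps]  # normalize so channels sum to 1

mean_prob = np.mean(prob_maps, axis=0)       # average the candidate predictions
final_label = np.argmax(mean_prob, axis=0)   # discrete segmentation from the averaged map
```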

## How to Run Model Ensemble Independently

The following Python script shows how to ensemble predictions from various algorithms using the `AlgoEnsembleBuilder` class.

```python
import os
from monai.apps.auto3dseg import (
    AlgoEnsembleBestN,
    AlgoEnsembleBuilder,
    import_bundle_algo_history,
)
from monai.utils.enums import AlgoKeys

# Assuming you have already trained the models

work_dir = <your_work_dir>  # the algorithm working directory generated by AlgoGen/BundleGen
input_cfg = <your_task_input_file>  # path to the task input YAML file created by the users

history = import_bundle_algo_history(work_dir, only_trained=True)

## model ensemble
n_best = 1
builder = AlgoEnsembleBuilder(history, input_cfg)
builder.set_ensemble_method(AlgoEnsembleBestN(n_best=n_best))
ensemble = builder.get_ensemble()
pred = ensemble()
print("ensemble picked the following best {0:d}:".format(n_best))
for algo in ensemble.get_algo_ensemble():
    print(algo[AlgoKeys.ID])
```
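The ensemble call returns one prediction per testing image. As a hedged follow-up sketch, assuming `pred` is a list of channel-first tensors, each result could be written to disk with `monai.transforms.SaveImage`; the output directory name here is illustrative:

```python
from monai.transforms import SaveImage

# Hypothetical output location; `pred` is assumed to be a list of channel-first
# (Meta)Tensors, one per testing image, as returned by the ensemble call above.
saver = SaveImage(output_dir="./ensemble_output", output_postfix="ensemble", separate_folder=False)
for p in pred:
    saver(p)
```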

## Customization

Auto3DSeg also provides an API for users to plug in their own customized ensemble algorithm, as shown in this notebook.
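As a rough illustration of what such a customization might look like (a sketch only, not the notebook's implementation; the class name, threshold rule, and score lookup below are assumptions), one could subclass `AlgoEnsemble` and implement `collect_algos` to decide which checkpoints participate in the ensemble:

```python
from monai.apps.auto3dseg import AlgoEnsemble
from monai.utils.enums import AlgoKeys


class AlgoEnsembleAboveThreshold(AlgoEnsemble):
    """Hypothetical ensemble that keeps every checkpoint whose validation score exceeds a threshold."""

    def __init__(self, threshold: float = 0.5):
        super().__init__()
        self.threshold = threshold

    def collect_algos(self) -> None:
        # `self.algos` is populated by AlgoEnsembleBuilder via set_algos(); each entry
        # carries the algorithm identifier and its best validation score.
        self.algo_ensemble = [a for a in self.algos if a[AlgoKeys.SCORE] > self.threshold]
```

The instance would then be passed to `builder.set_ensemble_method(...)` in place of `AlgoEnsembleBestN` in the script above.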

## Essential Component for General Algorithm/Model Ensemble

The essential component for the model ensemble is the `infer()` method of the `InferClass` class in the script `scripts/infer.py`. After the `InferClass` is initialized, `infer()` takes an image file name as input and outputs a multi-channel probability map. The `infer.py` of each algorithm is located inside its bundle template. In general, the ensemble module works for any algorithm as long as an `infer()` function with this behavior is provided.

```python
class InferClass:
    ...
    @torch.no_grad()
    def infer(self, image_file):
        self.model.eval()

        # load and pre-process the input image
        batch_data = self.infer_transforms(image_file)
        batch_data = list_data_collate([batch_data])
        infer_image = batch_data["image"].to(self.device)

        # sliding-window inference with mixed precision
        with torch.cuda.amp.autocast():
            batch_data["pred"] = sliding_window_inference(
                infer_image,
                self.patch_size_valid,
                self.num_sw_batch_size,
                self.model,
                mode="gaussian",
                overlap=self.overlap_ratio,
            )

        # apply post-processing transforms to each decollated item
        batch_data = [self.post_transforms(i) for i in decollate_batch(batch_data)]

        # return the multi-channel probability map
        return batch_data[0]["pred"]
    ...
```
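For a user-defined algorithm, a minimal wrapper that satisfies this contract might look like the sketch below. This is an assumed structure, not an existing MONAI bundle: the constructor arguments and the pre-processing hook are hypothetical, and only the `infer()` signature and its return value follow the contract described above.

```python
import torch


class InferClass:
    """Minimal sketch of the contract for a custom algorithm: `infer()` maps an
    image file name to a multi-channel probability map."""

    def __init__(self, model: torch.nn.Module, preprocess, device: str = "cuda"):
        self.model = model.to(device).eval()
        self.preprocess = preprocess  # e.g. a transform chain returning a dict with an "image" tensor
        self.device = device

    @torch.no_grad()
    def infer(self, image_file):
        data = self.preprocess(image_file)
        image = data["image"].unsqueeze(0).to(self.device)  # add a batch dimension
        logits = self.model(image)
        return torch.softmax(logits, dim=1)[0]  # multi-channel probability map, batch dim removed
```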