diff --git a/download_models.py b/download_models.py index 9390b98..38d044b 100644 --- a/download_models.py +++ b/download_models.py @@ -14,8 +14,8 @@ from requests.packages.urllib3.util import Retry MODELS = { - 'model_seg_unmyelinated_tem' : { - 'url': 'https://github.com/axondeepseg/model_seg_unmyelinated_tem/releases/download/v1.0.0/model_seg_unmyelinated_tem.zip', + 'model_seg_unmyelinated_sickkids_tem_best' : { + 'url': 'https://github.com/axondeepseg/model_seg_unmyelinated_tem/releases/download/v1.1.0/model_seg_unmyelinated_sickkids_tem_best.zip', 'description': 'Unmyelinated axon segmentation (1-class)', 'contrasts': ['TEM'], }, @@ -24,6 +24,11 @@ 'description': 'Axon and myelin segmentation on Toluidine Blue stained BF images (rabbit)', 'contrasts': ['BF'], }, + 'model_myelin_cutter': { + 'url': 'https://github.com/axondeepseg/model_postprocess_touching_myelin/releases/download/r20240313/model_myelin_cutter.zip', + 'description': 'Postprocessing myelin cutter model to delineate touching myelin. To use directly on an axonmyelin segmentation mask.', + 'contrasts': ['any'], + } } diff --git a/nn_axondeepseg.py b/nn_axondeepseg.py index 3e9c6b9..8542b8a 100644 --- a/nn_axondeepseg.py +++ b/nn_axondeepseg.py @@ -37,6 +37,8 @@ def get_parser(): parser.add_argument('--path-out', help='Path to output directory.', required=True) parser.add_argument('--path-model', default=None, help='Path to the model directory. This folder should contain individual folders like fold_0, fold_1, etc.',) + parser.add_argument('--use-best', action='store_true', default=False, + help='Use the best checkpoints instead of the final ones. Default: False') parser.add_argument('--use-gpu', action='store_true', default=False, help='Use GPU for inference. Default: False') return parser @@ -90,7 +92,12 @@ def main(): ) logger.info('Running inference on device: {}'.format(predictor.device)) # initialize network architecture, load checkpoint - predictor.initialize_from_trained_model_folder(path_model, use_folds=None) + checkpoint_name = 'checkpoint_final.pth' if not args.use_best else 'checkpoint_best.pth' + predictor.initialize_from_trained_model_folder( + path_model, + use_folds=None, + checkpoint_name=checkpoint_name + ) logger.info('Model loaded successfully.')