Skip to content

Commit

Permalink
Add support for best checkpoints + update download_models (#12)
Browse files Browse the repository at this point in the history
* Add support for best checkpoints

* Update release tag of TEM unmyelinated axon seg model

* Add myelin cutter model for download
  • Loading branch information
hermancollin authored Mar 13, 2024
1 parent 1c369ff commit f33f43b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
9 changes: 7 additions & 2 deletions download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from requests.packages.urllib3.util import Retry

MODELS = {
'model_seg_unmyelinated_tem' : {
'url': 'https://github.com/axondeepseg/model_seg_unmyelinated_tem/releases/download/v1.0.0/model_seg_unmyelinated_tem.zip',
'model_seg_unmyelinated_sickkids_tem_best' : {
'url': 'https://github.com/axondeepseg/model_seg_unmyelinated_tem/releases/download/v1.1.0/model_seg_unmyelinated_sickkids_tem_best.zip',
'description': 'Unmyelinated axon segmentation (1-class)',
'contrasts': ['TEM'],
},
Expand All @@ -24,6 +24,11 @@
'description': 'Axon and myelin segmentation on Toluidine Blue stained BF images (rabbit)',
'contrasts': ['BF'],
},
'model_myelin_cutter': {
'url': 'https://github.com/axondeepseg/model_postprocess_touching_myelin/releases/download/r20240313/model_myelin_cutter.zip',
'description': 'Postprocessing myelin cutter model to delineate touching myelin. To use directly on an axonmyelin segmentation mask.',
'contrasts': ['any'],
}
}


Expand Down
9 changes: 8 additions & 1 deletion nn_axondeepseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def get_parser():
parser.add_argument('--path-out', help='Path to output directory.', required=True)
parser.add_argument('--path-model', default=None,
help='Path to the model directory. This folder should contain individual folders like fold_0, fold_1, etc.',)
parser.add_argument('--use-best', action='store_true', default=False,
help='Use the best checkpoints instead of the final ones. Default: False')
parser.add_argument('--use-gpu', action='store_true', default=False,
help='Use GPU for inference. Default: False')
return parser
Expand Down Expand Up @@ -90,7 +92,12 @@ def main():
)
logger.info('Running inference on device: {}'.format(predictor.device))
# initialize network architecture, load checkpoint
predictor.initialize_from_trained_model_folder(path_model, use_folds=None)
checkpoint_name = 'checkpoint_final.pth' if not args.use_best else 'checkpoint_best.pth'
predictor.initialize_from_trained_model_folder(
path_model,
use_folds=None,
checkpoint_name=checkpoint_name
)
logger.info('Model loaded successfully.')


Expand Down

0 comments on commit f33f43b

Please sign in to comment.