diff --git a/examples/retinanet_inference_example.py b/examples/retinanet_inference_example.py index 4d63c9f..87f19d9 100644 --- a/examples/retinanet_inference_example.py +++ b/examples/retinanet_inference_example.py @@ -7,7 +7,6 @@ import yolk from yolk.parser import parse_args - def main(args=None): if args is None: args = sys.argv[1:] diff --git a/yolk/parser.py b/yolk/parser.py index 250d71d..ad61396 100644 --- a/yolk/parser.py +++ b/yolk/parser.py @@ -1,36 +1,79 @@ import argparse -def parse_args(args): - """ Parse the arguments. +def check_args(parsed_args): + """ Function to check for inherent contradictions within parsed arguments. + For example, batch_size < num_gpus + Intended to raise errors prior to backend initialisation. + + Args + parsed_args: parser.parse_args() + + Returns + parsed_args """ - parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.') - subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type') - subparsers.required = True - coco_parser = subparsers.add_parser('coco') - coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).') + if parsed_args.mode == 'train': + if parsed_args.dataset == None: + raise ValueError("Dataset type should be specified in training mode.") + elif parsed_args.datapath == None: + raise ValueError("Dataset path should be provided.") + + if parsed_args.multi_gpu > 1 and parsed_args.batch_size < parsed_args.multi_gpu: + raise ValueError( + "Batch size ({}) must be equal to or higher than the number of GPUs ({})".format(parsed_args.batch_size, + parsed_args.multi_gpu)) + + if parsed_args.multi_gpu > 1 and parsed_args.snapshot: + raise ValueError( + "Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format(parsed_args.multi_gpu, + parsed_args.snapshot)) + + if parsed_args.multi_gpu > 1 and not parsed_args.multi_gpu_force: + raise ValueError("Multi-GPU support is experimental, use at own risk! Run with --multi-gpu-force if you wish to continue.") - pascal_parser = subparsers.add_parser('pascal') - pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).') + if 'resnet' not in parsed_args.backbone: + warnings.warn('Using experimental backbone {}. Only resnet50 has been properly tested.'.format(parsed_args.backbone)) - kitti_parser = subparsers.add_parser('kitti') - kitti_parser.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).') + return parsed_args - def csv_list(string): - return string.split(',') - oid_parser = subparsers.add_parser('oid') - oid_parser.add_argument('main_dir', help='Path to dataset directory.') - oid_parser.add_argument('--version', help='The current dataset version is v4.', default='v4') - oid_parser.add_argument('--labels-filter', help='A list of labels to filter.', type=csv_list, default=None) - oid_parser.add_argument('--annotation-cache-dir', help='Path to store annotation cache.', default='.') - oid_parser.add_argument('--parent-label', help='Use the hierarchy children of this label.', default=None) - csv_parser = subparsers.add_parser('csv') - csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.') - csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.') - csv_parser.add_argument('--val-annotations', help='Path to CSV file containing annotations for validation (optional).') +def parse_args(args): + """ + Parse the arguments. + mandatory: + model: [str] retina / yolo / SSD + + optional: + --mode: [str] train / test (default) + --dataset: [str] coco / pascal / kitti / oid / csv (mandatory if it's training mode) + --datapath: [str] (mandatory in training mode) + """ + + parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.') + + # Select which backbone-model to use + model_parsers = parser.add_subparsers(help='Model to use in training/inference', dest='backbone_model') + model_parsers.required = True + + retinanet_parser = model_parsers.add_parser('retina') + retinanet_parser.add_argument('--mode', help='Select train or test mode', default='test', type=str) + retinanet_parser.add_argument('--dataset', help='Arguments for specific datasets', default=None, type=str) + retinanet_parser.add_argument('--datapath', help='Arguments for datasets direcetory', default=None, type=str) + + yolo_parser = model_parsers.add_parser('yolo') + yolo_parser.add_argument('--mode', help='Select train or test mode', default='test', type=str) + yolo_parser.add_argument('--dataset', help='Arguments for specific datasets', default=None, type=str) + yolo_parser.add_argument('--datapath', help='Arguments for datasets direcetory', default=None, type=str) + + ssd_parser = model_parsers.add_parser('SSD') + ssd_parser.add_argument('--mode', help='Select train or test mode', default='test', type=str) + ssd_parser.add_argument('--dataset', help='Arguments for specific datasets', default=None, type=str) + ssd_parser.add_argument('--datapath', help='Arguments for datasets direcetory', default=None, type=str) + + # TODO When using retinanet, there are variations of mandatory arguments according to the dataset used + group = parser.add_mutually_exclusive_group() group.add_argument('--snapshot', help='Resume training from a snapshot.') group.add_argument('--imagenet-weights', help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True) @@ -61,5 +104,5 @@ def csv_list(string): parser.add_argument('--multiprocessing', help='Use multiprocessing in fit_generator.', action='store_true') parser.add_argument('--workers', help='Number of generator workers.', type=int, default=1) parser.add_argument('--max-queue-size', help='Queue length for multiprocessing workers in fit_generator.', type=int, default=10) - - return parser.parse_args(args) \ No newline at end of file + + return check_args(parser.parse_args(args))