Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Argument Parser for general use #60

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/retinanet_inference_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import yolk
from yolk.parser import parse_args


def main(args=None):
if args is None:
args = sys.argv[1:]
Expand Down
88 changes: 63 additions & 25 deletions yolk/parser.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,74 @@
import argparse

def parse_args(args):
""" Parse the arguments.
def check_args(parsed_args):
MijeongJeon marked this conversation as resolved.
Show resolved Hide resolved
""" Function to check for inherent contradictions within parsed arguments.
For example, batch_size < num_gpus
Intended to raise errors prior to backend initialisation.

Args
parsed_args: parser.parse_args()

Returns
parsed_args
"""
parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')
subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
subparsers.required = True

coco_parser = subparsers.add_parser('coco')
coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
if parsed_args.mode == 'train' and parsed_args.dataset == None:
raise ValueError(
"Dataset type should be specified in training mode."
)

if parsed_args.multi_gpu > 1 and parsed_args.batch_size < parsed_args.multi_gpu:
raise ValueError(
"Batch size ({}) must be equal to or higher than the number of GPUs ({})".format(parsed_args.batch_size,
parsed_args.multi_gpu))

if parsed_args.multi_gpu > 1 and parsed_args.snapshot:
raise ValueError(
"Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format(parsed_args.multi_gpu,
parsed_args.snapshot))

pascal_parser = subparsers.add_parser('pascal')
pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
if parsed_args.multi_gpu > 1 and not parsed_args.multi_gpu_force:
raise ValueError("Multi-GPU support is experimental, use at own risk! Run with --multi-gpu-force if you wish to continue.")

kitti_parser = subparsers.add_parser('kitti')
kitti_parser.add_argument('kitti_path', help='Path to dataset directory (ie. /tmp/kitti).')
if 'resnet' not in parsed_args.backbone:
warnings.warn('Using experimental backbone {}. Only resnet50 has been properly tested.'.format(parsed_args.backbone))

def csv_list(string):
return string.split(',')
return parsed_args

oid_parser = subparsers.add_parser('oid')
oid_parser.add_argument('main_dir', help='Path to dataset directory.')
oid_parser.add_argument('--version', help='The current dataset version is v4.', default='v4')
oid_parser.add_argument('--labels-filter', help='A list of labels to filter.', type=csv_list, default=None)
oid_parser.add_argument('--annotation-cache-dir', help='Path to store annotation cache.', default='.')
oid_parser.add_argument('--parent-label', help='Use the hierarchy children of this label.', default=None)

csv_parser = subparsers.add_parser('csv')
csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')
csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
csv_parser.add_argument('--val-annotations', help='Path to CSV file containing annotations for validation (optional).')


def parse_args(args):
"""
Parse the arguments.
backbone-model train/test dataset-type
default:
- backbone-model : retina
- train/test : test
"""
parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')

# Select which backbone-model to use
model_parsers = parser.add_subparsers(help='Model to use in training/inference', dest='backbone_model')
model_parsers.required = True

retinanet_parser = model_parsers.add_parser('retina')
retinanet_parser.add_argument('--mode', help='Select train or test mode', default='test', type=str)
retinanet_parser.add_argument('--dataset', help='Arguments for specific datasets', default=None, type=str)
retinanet_parser.add_argument('--datapath', help='Arguments for datasets direcetory', default=None, type=str)

yolo_parser = model_parsers.add_parser('yolo')
yolo_parser.add_argument('--mode', help='Select train or test mode', default='test', type=str)
yolo_parser.add_argument('--dataset', help='Arguments for specific datasets', default=None, type=str)
yolo_parser.add_argument('--datapath', help='Arguments for datasets direcetory', default=None, type=str)

ssd_parser = model_parsers.add_parser('SSD')
ssd_parser.add_argument('--mode', help='Select train or test mode', default='test', type=str)
ssd_parser.add_argument('--dataset', help='Arguments for specific datasets', default=None, type=str)
ssd_parser.add_argument('--datapath', help='Arguments for datasets direcetory', default=None, type=str)

# TODO When using retinanet, there are variations of mandatory arguments according to the dataset used

group = parser.add_mutually_exclusive_group()
group.add_argument('--snapshot', help='Resume training from a snapshot.')
group.add_argument('--imagenet-weights', help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True)
Expand Down Expand Up @@ -61,5 +99,5 @@ def csv_list(string):
parser.add_argument('--multiprocessing', help='Use multiprocessing in fit_generator.', action='store_true')
parser.add_argument('--workers', help='Number of generator workers.', type=int, default=1)
parser.add_argument('--max-queue-size', help='Queue length for multiprocessing workers in fit_generator.', type=int, default=10)

return parser.parse_args(args)
return check_args(parser.parse_args(args))