From 1789beed248e5a934ac6ad313cadfd8f3f2221bc Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 16:52:45 -0700 Subject: [PATCH] Update eval_model.py --- .../super_resolution/eval_model.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/super_resolution/eval_model.py b/src/brevitas_examples/super_resolution/eval_model.py index a567469d1..60883dea6 100644 --- a/src/brevitas_examples/super_resolution/eval_model.py +++ b/src/brevitas_examples/super_resolution/eval_model.py @@ -31,13 +31,22 @@ parser = argparse.ArgumentParser(description='PyTorch BSD300 Validation') parser.add_argument('--data_root', help='Path to folder containing BSD300 val folder') -parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint') +parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint. Default = None') parser.add_argument( - '--save_path', type=str, default='outputs/', help='Save path for exported model') + '--save_path', + type=str, + default='outputs/', + help='Save path for exported model. Default = outputs/') parser.add_argument( - '--model', type=str, default='quant_espcn_x2_w8a8_base', help='Name of the model configuration') -parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers') -parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size') + '--model', + type=str, + default='quant_espcn_x2_w8a8_base', + help='Name of the model configuration. Default = quant_espcn_x2_w8a8_base') +parser.add_argument( + '--workers', type=int, default=0, help='Number of data loading workers. Default = 0') +parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size. Default = 16') +parser.add_argument( + '--crop_size', type=int, default=512, help='The size to crop the image. Default = 512') parser.add_argument('--use_pretrained', action='store_true', default=False) parser.add_argument('--eval_acc_bw', action='store_true', default=False) parser.add_argument('--save_model_io', action='store_true', default=False) @@ -60,6 +69,7 @@ def main(): num_workers=args.workers, batch_size=args.batch_size, upscale_factor=model.upscale_factor, + crop_size=args.crop_size, download=True) test_psnr = evaluate_avg_psnr(testloader, model)