-
Notifications
You must be signed in to change notification settings - Fork 0
/
arguments.py
40 lines (36 loc) · 3.2 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import argparse
def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)
parser.add_argument("--data_dir", default="data", type=str, help="path to dataset")
parser.add_argument("--checkpoint", default=None, type=str, help="path to model checkpoint")
parser.add_argument("--ae_checkpoint", default='checkpoints/plain-ae_latent256_chan64_pool-stride_upsample-deconv_bs50_lr0.001_wd0.01_e100_full_model.pth', type=str, help="path to AE checkpoint")
parser.add_argument("--classifier_checkpoint", default=None, type=str, help="path to classifier checkpoint")
parser.add_argument("--tag", default="", type=str, help="string to prepend when saving checkpoints")
parser.add_argument("--debug", dest="debug", help="print out shapes and values of intermediate outputs", action="store_true")
parser.add_argument("--log_tb", dest="log_tb", help="enable tensorboard plotting", action="store_true")
parser.add_argument("--tb_dir", default="logs", type=str, help="path to tensorboard files")
parser.add_argument("--loss", default="mse", type=str, help="reconstruction loss function")
parser.add_argument("--train_batch_size", default=50, type=int)
parser.add_argument("--test_batch_size", default=500, type=int)
parser.add_argument("--latent_size", default=256, type=int)
parser.add_argument("--num_channels", default=64, type=int)
parser.add_argument("--kernel_size", default=3, type=int)
parser.add_argument("--print_freq", default=1000, type=int)
parser.add_argument("--epochs", default=100, type=int)
parser.add_argument("--lr", default=0.001, type=float)
parser.add_argument("--wd", default=0.0, type=float)
parser.add_argument("--dropout", default=0, type=float)
parser.add_argument("--act", default='gelu', type=str, help='relu, leaky-relu, elu, selu, gelu, swish, mish')
parser.add_argument("--beta", default=5e-6, type=float)
parser.add_argument("--beta_mult", default=1.02, type=float)
parser.add_argument("--sample", dest="sample", help="generate an image from a random latent vector", action="store_true")
parser.add_argument("--evaluate", dest="evaluate", help="use pretrained model to reconstruct 4 images", action="store_true")
parser.add_argument("--train", dest="train", help="train model", action="store_true")
parser.add_argument("--variational", dest="variational", help="use variational autoencoder, otherwise use a plain one", action="store_true")
parser.add_argument("--sigmoid", dest="sigmoid", help="apply sigmoid at the end of decoder", action="store_true")
parser.add_argument("--mse", dest="mse", help="ise MSE loss instead of KL loss", action="store_true")
parser.add_argument("--no_pool", dest="no_pool", help="don't use max pooling for downsampling, use stride=2 conv", action="store_true")
parser.add_argument("--no_upsample", dest="no_upsample", help="don't use nn.Upsample for upsampling, use deconv", action="store_true")
parser.add_argument("--interpolation", default="nearest", type=str, help=f"upsample interpolation mode in the decoder, "
f"'nearest', 'linear', 'bilinear', 'bicubic' and 'trilinear'")
return parser