-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathparser.py
81 lines (67 loc) · 4.31 KB
/
parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
# Command-line settings shared by CrossPoint pretraining and the downstream
# fine-tuning / segmentation scripts.


def _str2bool(value):
    """Parse a command-line string such as ``True``/``false``/``1`` into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--flag False`` would
    silently turn the flag on.

    Args:
        value: the raw string argparse received for the option.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse reports this as a normal usage error.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# common settings for pretraining CrossPoint
parser = argparse.ArgumentParser(description='CrossPoint for Point Cloud Understanding')
parser.add_argument('--exp_name', type=str, default='exp', metavar='N',
                    help='Name of the experiment')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--model', type=str, default='dgcnn', metavar='N',
                    choices=['dgcnn', 'dgcnn_seg'],
                    help='Model to use, [dgcnn, dgcnn_seg]')
parser.add_argument('--resume', action="store_true", help='resume from checkpoint')
parser.add_argument('--eval', action='store_true', help='evaluate the model')
parser.add_argument('--batch_size', type=int, default=16, metavar='batch_size',
                    help='Size of batch')
parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                    help='Size of test batch')
parser.add_argument('--epochs', type=int, default=250, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--num_workers', type=int, default=0, metavar='num_workers',
                    help='number of processes to load data in host memory')
parser.add_argument('--use_sgd', action="store_true", help='Use SGD')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001, 0.1 if using sgd)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--num_pt_points', type=int, default=2048,
                    help='num of points when pretraining')
parser.add_argument('--num_ft_points', type=int, default=1024,
                    help='Number of points when finetuning')
parser.add_argument('--num_classes', type=int, default=40,
                    help='num of object classes in a dataset')
parser.add_argument('--ft_dataset', type=str, default='ModelNet40', help='finetune dataset')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout rate')
parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                    help='Dimension of embeddings')
parser.add_argument('--k', type=int, default=20, metavar='N',
                    help='Num of nearest neighbors to use')
parser.add_argument('--model_path', type=str, default='', metavar='N',
                    help='saved model path')
parser.add_argument('--img_model_path', type=str, default='', metavar='N',
                    help='saved image model path')
parser.add_argument('--save_freq', type=int, default=50, help='save frequency')
parser.add_argument('--print_freq', type=int, default=200, help='print frequency')
# training on single GPU device
# _str2bool parses the text explicitly (a plain type=bool would make
# "--no_cuda False" True); nargs='?' with const=True also lets the bare
# "--no_cuda" flag work, while "--no_cuda True/False" keeps working.
parser.add_argument('--no_cuda', type=_str2bool, nargs='?', const=True, default=False,
                    metavar='BOOL',
                    help='disables CUDA training (bare flag, or --no_cuda True/False)')
parser.add_argument('--gpu_id', type=int, default=0,
                    help='specify the GPU device to train or finetune the model')
# distributed training on multiple GPUs
parser.add_argument('--rank', type=int, default=-1,
                    help='the rank for current GPU or process, usually one process per GPU')
parser.add_argument('--backend', type=str, default='nccl', help='DDP communication backend')
parser.add_argument('--world_size', type=int, default=6, help='number of GPUs')
parser.add_argument('--master_addr', type=str, default='localhost', help='ip of master node')
parser.add_argument('--master_port', type=str, default='12355', help='port of master node')
# downstream task: Segmentation settings
parser.add_argument('--class_choice', type=str, default=None, metavar='N',
                    choices=['airplane', 'bag', 'cap', 'car', 'chair',
                             'earphone', 'guitar', 'knife', 'lamp', 'laptop',
                             'motor', 'mug', 'pistol', 'rocket', 'skateboard', 'table'],
                    help='single object category to use for segmentation (default: None)')
parser.add_argument('--scheduler', type=str, default='cos', metavar='N',
                    choices=['cos', 'step'], help='Scheduler to use, [cos, step]')
# wandb settings
# NOTE(review): a key passed on the command line is visible in process
# listings and shell history; consider reading it from the environment.
parser.add_argument('--wb_url', type=str, default="http://localhost:28282", help='your wandb server url')
parser.add_argument('--wb_key', type=str, default="", help='your wandb login key')
# Parse the real command line (sys.argv[1:]) eagerly, at import time;
# argparse exits with a usage error on unknown or invalid options.
# NOTE(review): this makes the module unimportable under a foreign argv
# (e.g. pytest, notebooks) — consider parse_known_args() or a lazy accessor.
args = parser.parse_args(args=None)