-
Notifications
You must be signed in to change notification settings - Fork 7
/
config.py
57 lines (43 loc) · 1.46 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
class CFG:
    """Static configuration namespace for training and evaluation.

    Every setting is a plain class attribute, read as ``CFG.<name>``; the
    class is not meant to be instantiated.  Settings are grouped into:
    resume state, model hyper-parameters, per-task dataset / weight paths,
    optimizer settings and evaluation settings.

    All relative paths are written relative to the process working
    directory (they start with ``../``), so scripts using them must be
    launched from the directory the paths were written for.
    """

    # resume training
    resume = False                # whether to resume from a saved state
    resume_loss = float('inf')    # initial loss value used when resuming
    resume_path = ''              # checkpoint to resume from (empty = unset)
    start_epoch = 0

    # First CUDA device when available, otherwise CPU.  Evaluated once, when
    # this class body executes (i.e. at import time of the module).
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # model config
    max_len = 300                 # presumably the max sequence length -- TODO confirm
    img_size = 384
    num_bins = img_size           # kept equal to the image size
    num_classes = 91
    batch_size = 16
    epochs = 2
    model_name = 'deit3_small_patch16_384_in21ft1k'
    # Patch edge length; must match the "patch16" part of ``model_name``.
    patch_size = 16
    # Derived instead of hard-coded so it cannot drift from ``img_size``:
    # (384 // 16) ** 2 == 576, the value previously written literally.
    num_patches = (img_size // patch_size) ** 2

    # voc dataset params
    img_path = '../download/VOCdevkit/VOC2012/JPEGImages'
    xml_path = '../download/VOCdevkit/VOC2012/Annotations'
    voc_label_path = '../train/voc_classes.txt'
    voc_weight_path = '../weights/voc_object_detection.pth'

    # coco dataset params
    dir_root = '/mnt/MSCOCO'
    coco_label_path = '../train/coco91_indices.json'
    coco_weight_path = '../weights/coco_ob_wo_pixnorm.pth'

    # image captioning params
    vocab_path = '../train/vocab.pkl'
    coco_caption_weight_path = '../weights/coco_image_caption.pth'

    # keypoint detection params
    keypoints_path = '../train/person_keypoints.json'
    coco_keypoint_weight_path = '../weights/coco_keypoint.pth'
    num_joints = 17  # for the COCO dataset; noted by the author as unused

    # segmentation params
    coco_seg_weight_path = '../weights/coco_segmentation.pth'

    # multi task params
    multi_task_weight_path = '../weights/coco_multi_task.pth'

    # optim
    lr = 1e-4
    weight_decay = 1e-4

    # eval
    generation_steps = 101        # presumably the decoding step budget -- TODO confirm
    run_eval = True