-
Notifications
You must be signed in to change notification settings - Fork 2
/
03_pg_dataset.py
34 lines (26 loc) · 1.28 KB
/
03_pg_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""Third step of our approach: create PyTorch Geometric train/val datasets.

This takes a long time to run when a new dataset has to be created.
"""
import src.config as cfg
from src.davis_2016 import DAVIS2016
def _build_dataset(augmentation_count, train):
    """Construct one split of the DAVIS 2016 PyTorch Geometric dataset.

    Both splits share every configuration argument except the augmentation
    count and the ``train`` flag. Building them through this helper keeps
    the long positional argument list in a single place, so the two splits
    cannot silently drift apart.

    Args:
        augmentation_count: Value passed to ``DAVIS2016`` in the slot that
            the training split fills with ``cfg.AUGMENTATION_COUNT``.
        train: ``True`` for the training split, ``False`` for validation.

    Returns:
        The constructed ``DAVIS2016`` instance.
    """
    return DAVIS2016(cfg.PYTORCH_GEOMETRIC_DAVIS_2016_DATASET_PATH,
                     cfg.ANNOTATIONS_AUGMENTED_FOLDERS_PATH,
                     cfg.CONTOURS_FOLDERS_PATH,
                     cfg.IMAGES_AUGMENTED_FOLDERS_PATH,
                     cfg.TRANSLATIONS_FOLDERS_PATH,
                     cfg.PARENT_MODEL_PATH,
                     cfg.LAYER, cfg.K, augmentation_count,
                     cfg.SKIP_SEQUENCES,
                     cfg.TRAIN_SEQUENCES, cfg.VAL_SEQUENCES,
                     train=train)


def main():
    """Build the train and validation datasets and print their sizes.

    This can take a long time when a new dataset has to be created.
    """
    train = _build_dataset(cfg.AUGMENTATION_COUNT, train=True)
    # The validation split passes 0 where the training split passes
    # cfg.AUGMENTATION_COUNT.
    val = _build_dataset(0, train=False)
    print("Train size: %i" % len(train))
    print("Val size: %i" % len(val))


if __name__ == "__main__":
    main()