# config_pretext.yaml
# Batch size used for both training and evaluation.
batch_size: 64

# Number of epochs to train; ignored when running evaluation.
epochs: 1

# Validation interval, in epochs.
validate_every_n_epochs: 1

# Evaluation interval, in epochs.
eval_every_n_epochs: 1

# Path to the model to fine-tune or evaluate.
# NOTE(review): YAML loads the bare word `None` as the string "None", not as
# null - confirm the consumer expects that sentinel.
fine_tune_from: None

# TensorBoard logging interval, in steps.
log_every_n_steps: 100

# Weight decay used during training.
# NOTE(review): 10e-6 equals 1e-5. With no decimal point, YAML 1.1 loaders
# (e.g. PyYAML) return it as a string while YAML 1.2 loaders return a float -
# confirm how the consumer converts it before changing the spelling.
weight_decay: 10e-6
# Model parameters.
model:
  # Size of the projection output.
  out_dim: 128
  # Base model to be used. Choices: "resnet18", "resnet50", "scatsimclr8",
  # "scatsimclr12", "scatsimclr16", "scatsimclr30".
  base_model: "scatsimclr30"
  # J - scale parameter of ScatNet. More: https://www.kymat.io/
  J: 2
  # L - rotation parameter of ScatNet. More: https://www.kymat.io/
  L: 16
# Dataset parameters.
dataset:
  # Dataset name: "stl10", "cifar20", "cifar10".
  dataset: "stl10"
  # Input shape: (96, 96, 3) for "stl10"; (32, 32, 3) for "cifar20" and
  # "cifar10".
  # NOTE(review): YAML has no tuple syntax, so this value loads as the string
  # "(96, 96, 3)" - confirm the consumer parses it.
  input_shape: (96, 96, 3)
  # Portion of the data used for validation.
  valid_size: 0.05
# Loss parameters.
loss:
  # Temperature setting for the loss.
  temperature: 0.5
  # Whether the loss uses cosine similarity (per the key name - confirm
  # against the loss implementation).
  use_cosine_similarity: true
# Pretext learning parameters.
pretext:
  # If true, the jigsaw pretext task is used. Only one of `jigsaw` or
  # `rotation` can be true.
  jigsaw: false
  # If true, the rotation pretext task is used.
  rotation: true
  # Number of jigsaw permutations; ignored when rotation is used. Each
  # num_jigsaw value requires a permutations file in the data folder named
  # permutations_{num_jigsaw}.npy.
  num_jigsaw: 35
  # Loss weight parameter for the pretext task.
  lambda: 0.3