-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathconfig.py
72 lines (61 loc) · 1.37 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from torch import optim, nn
import multiprocessing as mp
from utils import datasets, samplers, schedulers
# Runtime settings shared by the training harness.
# Prefer the first CUDA device when one is visible; otherwise run on CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# One DataLoader worker per logical CPU core.
num_workers = mp.cpu_count()
# Total number of passes over the training set.
epochs = 100
# Component specs for the training harness. Each spec is a dict that the
# harness resolves into an object: "cls" names the class to instantiate and
# "params" holds the keyword arguments for its constructor. All entries are
# left commented out here as a fill-in template, so each spec is empty.
model = {
    # "cls": model.Model,
    # "params": {},  # keyword arguments for the Model constructor
}
loss_fn = {
    # "cls": nn.MSELoss,
    # "params": {},  # keyword arguments for the loss constructor
}
optimizer = {
    # "cls": optim.Adam,
    # "params": {},  # keyword arguments for the optimizer constructor
}
scheduler = {
    # "cls": schedulers.Scheduler,
    # "params": {},  # keyword arguments for the Scheduler constructor
}
# Dataset specs. Besides "cls"/"params" (class and constructor kwargs),
# each spec may carry a "loader" dict of keyword arguments forwarded to the
# DataLoader. Both are empty templates until filled in.
train_dataset = {
    # "cls": datasets.Dataset,
    # "params": {},  # keyword arguments for the Dataset constructor
    # "loader": {},  # keyword arguments for the DataLoader
}
val_dataset = {
    # "cls": datasets.Dataset,
    # "params": {},  # keyword arguments for the Dataset constructor
    # "loader": {},  # keyword arguments for the DataLoader
}
# Batch-sampler specs ("cls" plus constructor "params"), one for the
# training loader and one for validation. Empty templates until filled in.
train_batch_sampler = {
    # "cls": samplers.Sampler,
    # "params": {},  # keyword arguments for the Sampler constructor
}
val_batch_sampler = {
    # "cls": samplers.Sampler,
    # "params": {},  # keyword arguments for the Sampler constructor
}