-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
82 lines (73 loc) · 2.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from dgl.dataloading import GraphDataLoader
from solvers.graph_model import AttentionGNN
from utils import sample_dataset, graph_from_problem, os_type, ss_type, GraphDataset
from tqdm.auto import tqdm
import numpy as np
def prepare_train_val_dataset(n_tasks, n_operations, n_cities, path_to_dataset,
                              *, n_samples=100, train_fraction=0.8,
                              threshold=0.2, random_seed=0):
    """Sample a synthetic problem set for one size configuration and split it.

    Every problem is requested with the same size: ``n_tasks`` tasks,
    ``n_operations`` operations and ``n_cities`` cities. Each is passed to
    ``sample_dataset`` as a two-element list; this is presumably a
    ``[min, max]`` range, so confirm against ``sample_dataset``.

    Args:
        n_tasks: Number of tasks per problem.
        n_operations: Number of operations per problem.
        n_cities: Number of cities per problem.
        path_to_dataset: Root directory. ``sample_dataset`` receives
            ``dirpath='{path_to_dataset}/{n_tasks}-{n_operations}-{n_cities}/'``.
        n_samples: Number of problems requested from ``sample_dataset``
            (default 100, matching the previous hard-coded value).
        train_fraction: Fraction of ``n_samples`` placed in the training
            split. The rest becomes the validation split. The default 0.8
            reproduces the previous 80/20 split.
        threshold: Forwarded unchanged to ``sample_dataset``.
        random_seed: Forwarded unchanged to ``sample_dataset``.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple of slices of the sampled
        dataset.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f'train_fraction must be in [0, 1], got {train_fraction}')
    dataset = sample_dataset(
        n_samples,
        [n_tasks, n_tasks],
        [n_operations, n_operations],
        [n_cities, n_cities],
        threshold=threshold,
        dirpath=f'{path_to_dataset}/{n_tasks}-{n_operations}-{n_cities}/',
        random_seed=random_seed,
    )
    # Derive the split point from the requested sample count instead of a
    # separate hard-coded 80. Changing n_samples then keeps the intended ratio.
    n_train = round(n_samples * train_fraction)
    train_dataset = dataset[:n_train]
    val_dataset = dataset[n_train:]
    return train_dataset, val_dataset
def prepare_graphs(dataset, n_tasks, n_operations, n_cities, path_to_dataset):
    """Convert sampled problems into graphs with rescaled edge features.

    For each problem:
    - load ``gamma.npy`` from
      ``{path_to_dataset}/{n_tasks}-{n_operations}-{n_cities}/{name}/gamma.npy``;
    - build a graph with ``graph_from_problem(..., max_operations=n_operations)``;
    - rescale edge features in place: column 0 of the ``os_type`` edge
      features is divided by 10, and every ``ss_type`` edge feature is
      divided by 100.

    Args:
        dataset: Iterable of problems. Each must support ``problem['name']``.
        n_tasks: Tasks per problem (used only to locate the data directory).
        n_operations: Operations per problem. Used for the data directory and
            as ``max_operations``.
        n_cities: Cities per problem (used only to locate the data directory).
        path_to_dataset: Root directory of the sampled data.

    Returns:
        A list of graphs, in the same order as ``dataset``.
    """
    case_dir = f'{path_to_dataset}/{n_tasks}-{n_operations}-{n_cities}'

    def _to_graph(problem):
        # Build one problem's graph and rescale its edge features in place.
        problem_name = problem['name']
        gamma_matrix = np.load(f'{case_dir}/{problem_name}/gamma.npy')
        g = graph_from_problem(problem, gamma_matrix, max_operations=n_operations)
        # NOTE(review): the 10 and 100 divisors look like fixed normalisation
        # constants; their meaning is not visible here, so confirm upstream.
        g.edata['feat'][os_type][:, 0] /= 10
        g.edata['feat'][ss_type][:] /= 100
        return g

    return [_to_graph(problem) for problem in dataset]
# Size configurations (n_tasks, n_operations, n_cities) trained by default.
# Stored as a tuple so it cannot be mutated between calls.
_DEFAULT_CASES = (
    (5, 5, 5),
    (5, 10, 10),
    (10, 10, 10),
    (5, 10, 20),
    (5, 20, 10),
    (5, 20, 20),
)


def _build_dataloaders(n_tasks, n_operations, n_cities, path_to_dataset):
    """Sample and featurize one configuration, then wrap train/val graphs in loaders."""
    train_dataset, val_dataset = prepare_train_val_dataset(
        n_tasks, n_operations, n_cities, path_to_dataset)
    train_graphs = prepare_graphs(
        train_dataset, n_tasks, n_operations, n_cities, path_to_dataset)
    val_graphs = prepare_graphs(
        val_dataset, n_tasks, n_operations, n_cities, path_to_dataset)
    train_graph_dataset = GraphDataset(train_graphs)
    val_graph_dataset = GraphDataset(val_graphs)
    train_dataloader = GraphDataLoader(
        train_graph_dataset, batch_size=8, shuffle=True)
    # batch_size=100 is large enough to hold the default-sized validation
    # split (the last 20 of 100 sampled problems) in a single batch.
    val_dataloader = GraphDataLoader(val_graph_dataset, batch_size=100)
    return train_dataloader, val_dataloader


def _build_model(n_operations):
    """Create the AttentionGNN. The 5-operation case gets a smaller network."""
    small = n_operations == 5
    return AttentionGNN(
        ins_dim=1,
        ino_dim=n_operations,
        out_dim=16 if small else 32,
        n_layers=1 if small else 3,
        lr=0.002,
    )


def train(cases=None, path_to_dataset='data/synthetic',
          log_dir='training_gnn', max_epochs=100):
    """Train one AttentionGNN per problem-size configuration on the CPU.

    With no arguments this behaves exactly as before: it trains every case in
    ``_DEFAULT_CASES`` on data under ``data/synthetic`` for 100 epochs. CSV
    logs are written under ``training_gnn/{n_tasks}-{n_operations}-{n_cities}``.

    Args:
        cases: Iterable of ``(n_tasks, n_operations, n_cities)`` triples.
            ``None`` selects ``_DEFAULT_CASES``.
        path_to_dataset: Root directory passed to the dataset helpers.
        log_dir: Root directory for the per-case ``CSVLogger`` output.
        max_epochs: Forwarded to ``Trainer(max_epochs=...)``.
    """
    if cases is None:
        cases = _DEFAULT_CASES
    for (n_tasks, n_operations, n_cities) in tqdm(cases, desc='Training GNN models...'):
        train_dataloader, val_dataloader = _build_dataloaders(
            n_tasks, n_operations, n_cities, path_to_dataset)
        model = _build_model(n_operations)
        trainer = Trainer(
            # The outer tqdm already reports progress per case.
            enable_progress_bar=False,
            max_epochs=max_epochs,
            log_every_n_steps=1,
            logger=CSVLogger(f'{log_dir}/{n_tasks}-{n_operations}-{n_cities}'),
            accelerator='cpu',
        )
        trainer.fit(
            model=model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )
if __name__ == '__main__':
    # Script entry point: run the full training sweep with default settings.
    train()