-
Notifications
You must be signed in to change notification settings - Fork 17
/
data.py
68 lines (58 loc) · 2.37 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from torch.utils.data import Dataset
import os
import pickle
import numpy as np
import networkx as nx
import dgl
from collections import Counter
class RetroCenterDatasets(Dataset):
    """Dataset of pickled reactions for reaction-center prediction.

    Each ``*.pkl`` file under ``<root>/<data_split>`` holds one reaction as a
    dict. The keys read here are ``product_adj``, ``target_adj``,
    ``product_atom_features``, ``pattern_feat``, ``product_bond_features``
    and ``rxn_type``.

    NOTE(review): files are read with ``pickle.load``, which can execute
    arbitrary code. Only point this dataset at trusted, locally generated data.
    """

    # Width of the one-hot reaction-class encoding; ``rxn_type`` is used as a
    # row index into ``np.eye(NUM_RXN_CLASSES)``.
    NUM_RXN_CLASSES = 10

    # Disconnection counts at or above this value are merged into one label.
    MAX_DISCONNECTION_NUM = 2

    def __init__(self, root, data_split):
        """Index the split directory and precompute per-reaction labels.

        Args:
            root: Dataset root directory.
            data_split: Sub-directory of ``root`` to load (e.g. ``'train'``).

        Side effects:
            Unpickles every ``.pkl`` file once and prints a ``Counter`` of the
            raw (unclamped) disconnection counts.
        """
        self.root = root
        self.data_split = data_split
        self.data_dir = os.path.join(root, self.data_split)
        # Sorted so that dataset indices are stable across runs/platforms.
        self.data_files = sorted(
            f for f in os.listdir(self.data_dir) if f.endswith('.pkl'))

        self.disconnection_num = []
        cnt = Counter()
        for data_file in self.data_files:
            reaction_data = self._load_reaction(data_file)
            raw_num = self._count_disconnections(
                reaction_data['product_adj'], reaction_data['target_adj'])
            # The histogram keeps raw counts; the stored label is clamped.
            cnt[raw_num] += 1
            self.disconnection_num.append(
                min(raw_num, self.MAX_DISCONNECTION_NUM))
        print(cnt)

    def _load_reaction(self, data_file):
        """Unpickle and return the reaction dict stored in ``data_file``."""
        # NOTE(review): pickle.load can execute arbitrary code; data must be
        # trusted.
        with open(os.path.join(self.data_dir, data_file), 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _count_disconnections(product_adj, target_adj):
        """Return how many product bonds are absent from the target graph.

        An entry counts when it is set in ``product_adj`` and unset in
        ``target_adj``. The total is halved because each bond is expected to
        appear twice. NOTE(review): this assumes both adjacency matrices are
        symmetric; confirm against the data generator.
        """
        broken = np.logical_and(product_adj, np.logical_not(target_adj))
        return int(broken.sum()) // 2

    def __getitem__(self, index):
        """Load one reaction and build its model inputs.

        Returns:
            A tuple ``(rxn_class, x_pattern_feat, x_atom, x_adj, x_graph,
            y_adj, disconnection_num)``. ``rxn_class`` is the one-hot reaction
            class repeated once per product atom (``len(x_atom)`` rows), and
            ``x_graph`` is a DGL graph built from ``x_adj`` with bond features
            stored in ``edata['w']``.
        """
        reaction_data = self._load_reaction(self.data_files[index])
        x_atom = reaction_data['product_atom_features'].astype(np.float32)
        x_pattern_feat = reaction_data['pattern_feat'].astype(np.float32)
        x_bond = reaction_data['product_bond_features'].astype(np.float32)
        x_adj = reaction_data['product_adj']
        y_adj = reaction_data['target_adj']

        # One-hot encode the reaction class and give every product atom a copy.
        one_hot = np.eye(self.NUM_RXN_CLASSES)[reaction_data['rxn_type']]
        rxn_class = np.repeat(one_hot[np.newaxis], len(x_atom), axis=0)

        disconnection_num = self.disconnection_num[index]

        # nx.from_numpy_matrix was removed in NetworkX 3.0. from_numpy_array
        # is the equivalent constructor and has been available since 2.0.
        x_graph = dgl.DGLGraph(nx.from_numpy_array(x_adj))
        # Assumes x_adj is a boolean mask. Mask indexing returns bond features
        # in row-major (C) order of the True entries.
        # NOTE(review): this relies on DGL's edge IDs following that same
        # order after the NetworkX conversion; confirm for the installed DGL.
        x_graph.edata['w'] = x_bond[x_adj]
        return (rxn_class, x_pattern_feat, x_atom, x_adj, x_graph, y_adj,
                disconnection_num)

    def __len__(self):
        """Return the number of ``.pkl`` reaction files in the split."""
        return len(self.data_files)
if __name__ == '__main__':
    # Smoke check: index every split of USPTO-50K and list its first files.
    data_root = 'data/USPTO50K/'
    for data_set in ['train', 'test', 'valid']:
        dataset = RetroCenterDatasets(root=data_root, data_split=data_set)
        print(dataset.data_files[:100])