-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdata_preprocess.py
379 lines (315 loc) · 11.1 KB
/
data_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
"""
Data processing.
#TODO: args, format
"""
import argparse
import json
import os
import random

import networkx as nx
import numpy as np
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import ChemicalFeatures

from utils.file_utils import *
def _collate_molecule(molecule, self_loop=True):
    """Turn one raw molecule record into a hydrogen-free adjacency representation.

    Steps:
      1. count, for every non-hydrogen atom, how many hydrogen neighbours it has;
      2. drop the hydrogen atoms (``number == 1``);
      3. renumber the remaining atoms 0..n-1 (in input order) and keep only
         the bonds between them.

    Args:
        molecule: dict with an ``atoms`` list (each entry has at least ``aid``
            and ``number``; ``charge`` is optional) and a ``bonds`` list (each
            entry has ``aid1``, ``aid2``, ``order`` and an optional ``style``).
        self_loop: if True, every atom also gets a bond to itself labelled
            ``'self'``.

    Returns:
        List of atom records ordered by their new ``aid``. Each record keeps
        the raw atom fields plus ``charge`` (default 0), ``n_hydro``, the new
        ``aid`` and ``nbr``: a list of ``(neighbour_aid, bond_label)`` pairs,
        where ``bond_label`` is ``'<order>-<style>'`` or ``'self'``.
    """
    atoms = {a['aid']: a for a in molecule['atoms']}

    # Symmetric bond labels: both (aid1, aid2) and (aid2, aid1) map to the label.
    bonds = {}
    for b in molecule['bonds']:
        label = '{}-{}'.format(b['order'], b.get('style', 0))
        bonds[(b['aid1'], b['aid2'])] = label
        bonds[(b['aid2'], b['aid1'])] = label

    if self_loop:
        self_bonds = {(aid, aid): 'self' for aid in atoms}
        # Raw data must not already contain self-bonds; otherwise they would be
        # silently overwritten by the 'self' label below.
        assert set(self_bonds).isdisjoint(bonds), 'Raw molecule already has self-bonds.'
        bonds = {**bonds, **self_bonds}

    # Adjacency list {aid: [(nbr_aid, label), ...]} built in one pass over the
    # bonds; every atom gets an entry, even if it has no bonds.
    adjacency = {aid: [] for aid in atoms}
    for (aid1, aid2), label in bonds.items():
        if aid1 in adjacency:
            adjacency[aid1].append((aid2, label))

    # Hydrogen bookkeeping: {non-hydrogen aid: number of hydrogen neighbours}.
    h_aids = {aid for aid, atom in atoms.items() if atom['number'] == 1}
    h_count = {aid: sum(1 for nbr, _ in adjacency[aid] if nbr in h_aids)
               for aid in atoms if aid not in h_aids}
    assert len(h_aids) == sum(h_count.values())
    assert all(a.get('charge', 0) == 0 for a in atoms.values() if a['number'] == 1)

    # Drop hydrogens; the position among the remaining atoms becomes the new aid.
    heavy = [(aid, a) for aid, a in atoms.items() if a['number'] > 1]
    renumbered = {
        aid: {**a, 'charge': a.get('charge', 0), 'n_hydro': h_count.get(aid, 0), 'aid': idx}
        for idx, (aid, a) in enumerate(heavy)
    }

    collated = []
    for aid, _ in heavy:  # already in new-aid order
        nbrs = [(renumbered[nbr]['aid'], label)
                for nbr, label in adjacency[aid] if nbr not in h_aids]
        collated.append({**renumbered[aid], 'nbr': nbrs})
    assert all(i == a['aid'] for i, a in enumerate(collated))
    return collated


def _build_graph_idx_mapping(molecule, bond_type_idx=None):
    """Flatten a collated molecule into parallel index lists.

    Args:
        molecule: output of ``_collate_molecule``; atom records ordered by
            ``aid`` (0..n-1), each with ``number``, ``n_hydro``, ``charge`` and
            ``nbr``.
        bond_type_idx: optional mapping from bond label to integer index. When
            given (and non-empty), bond labels are replaced by their indices.

    Returns:
        dict with ``n_atom``; per-atom ``atom_type`` (``number``) and
        ``atom_feat`` ``(number, n_hydro, charge)``; and per-edge ``bond_type``,
        ``bond_seg_i`` (source atom) and ``bond_idx_j`` (neighbour atom).
    """
    atom_type, atom_feat = [], []
    bond_type, bond_seg_i, bond_idx_j = [], [], []
    for i, atom in enumerate(molecule):
        aid = atom['aid']
        assert aid == i
        # Split [(nbr_aid, label), ...] into parallel tuples. An atom with no
        # neighbours (possible when self_loop=False) yields empty tuples.
        nbr_ids, nbr_bonds = zip(*atom['nbr']) if atom['nbr'] else ((), ())
        assert len(set(nbr_ids)) == len(nbr_ids), 'Multi-graph is not supported.'
        if bond_type_idx:
            nbr_bonds = [bond_type_idx.get(b) for b in nbr_bonds]
        atom_feat.append((atom['number'], atom['n_hydro'], atom['charge']))
        atom_type.append(atom['number'])
        bond_type.extend(nbr_bonds)
        # Source atom repeated once per outgoing edge.
        bond_seg_i.extend([aid] * len(nbr_ids))
        bond_idx_j.extend(nbr_ids)
    return {'n_atom': len(molecule),
            'atom_type': atom_type,
            'atom_feat': atom_feat,
            'bond_type': bond_type,
            'bond_seg_i': bond_seg_i,
            'bond_idx_j': bond_idx_j}


def preprocess_decagon(dir_path='./data/'):
    """Convert raw DECAGON drug records into graph-index JSON Lines files.

    Reads ``drug_raw_feat.idx.jsonl`` from ``dir_path``. Each non-blank line is
    ``<id>\\t<json>``, and the JSON holds ``atoms`` and ``bonds`` lists.
    Hydrogens are folded into an ``n_hydro`` count on their neighbour, and
    self-loops are added. Two files are written to ``dir_path``:

    * ``drug.feat.wo_h.self_loop.idx.jsonl``: ``<id>\\t<graph json>`` per drug
      (layout as returned by ``_build_graph_idx_mapping``);
    * ``drug.bond_idx.wo_h.self_loop.json``: the bond-label -> index mapping.

    Args:
        dir_path: directory holding the raw file; outputs are written there too.
            A trailing path separator is optional.
    """
    raw_drugs = {}
    with open(os.path.join(dir_path, 'drug_raw_feat.idx.jsonl')) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            idx, record = line.split('\t', 1)
            raw_drugs[idx] = json.loads(record)

    atom_attr_keys = {k for d in raw_drugs.values() for a in d['atoms'] for k in a}
    print('Possible atom attribute names:', atom_attr_keys)
    bond_attr_keys = {k for d in raw_drugs.values() for b in d['bonds'] for k in b}
    print('Possible bond attribute names:', bond_attr_keys)

    drug_structure_dict = {cid: _collate_molecule(d, self_loop=True)
                           for cid, d in raw_drugs.items()}

    # All bond labels ('<order>-<style>' or 'self') seen in the data set.
    bond_types = {b for d in drug_structure_dict.values()
                  for a in d for _, b in a['nbr']}
    # Sort before numbering: iteration order of a set of str changes between
    # interpreter runs (hash randomization), which would make indices unstable.
    bond_type_idx = {b: i for i, b in enumerate(sorted(bond_types))}
    print('Bond to idx dict:', bond_type_idx)

    drug_graph_dict = {
        cid: _build_graph_idx_mapping(d, bond_type_idx=bond_type_idx)
        for cid, d in drug_structure_dict.items()}

    with open(os.path.join(dir_path, 'drug.feat.wo_h.self_loop.idx.jsonl'), 'w') as f:
        for cid, d in drug_graph_dict.items():
            f.write('{}\t{}\n'.format(cid, json.dumps(d)))
    with open(os.path.join(dir_path, 'drug.bond_idx.wo_h.self_loop.json'), 'w') as f:
        f.write(json.dumps(bond_type_idx))
def _qm9_graph_properties(prop):
    """Parse the property line (second line) of a QM9 ``.xyz`` file.

    Returns:
        (props, labels): ``props`` maps 'tag', 'index' and the 15 named float
        fields (A, B, C, mu, ..., Cv) to their values. ``labels`` is the list
        [mu, alpha, homo, lumo, gap, r2, zpve, U0, U, H, G, Cv].

    Raises:
        ValueError: if the line has fewer than 17 whitespace-separated fields
            or a numeric field cannot be parsed.
    """
    float_names = ('A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap',
                   'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv')
    fields = prop.split()
    if len(fields) < 2 + len(float_names):
        raise ValueError('Malformed QM9 property line: {!r}'.format(prop))
    props = {'tag': fields[0], 'index': int(fields[1])}
    props.update(zip(float_names, map(float, fields[2:2 + len(float_names)])))
    # Everything from 'mu' onwards is used as a label.
    labels = [props[name] for name in float_names[3:]]
    return props, labels


def _qm9_bond_type_idx(bond_type):
    """Map an RDKit bond order (``GetBondTypeAsDouble``) to a bond index.

    0 is reserved for self-loops; 1/2/3 are single/double/triple bonds, and
    4 is aromatic (reported by RDKit as 1.5).
    """
    if bond_type == 1.5:
        return 4
    return int(bond_type)


def _qm9_xyz_graph_reader(graph_file, self_loop=True):
    """Read one QM9 ``.xyz`` file into the graph-index representation.

    The graph is rebuilt from the first SMILES string in the file (with
    explicit hydrogens added). The per-atom XYZ lines are read past but
    not used.

    Returns:
        (mol_idx, mol_representation, labels, smiles)

    Raises:
        ValueError: if RDKit cannot parse the SMILES string.
    """
    with open(graph_file, 'r') as f:
        n_atom = int(f.readline())
        prop_dict, labels = _qm9_graph_properties(f.readline())
        # TODO: coordinates (columns 1-3 of each atom line) and other atom
        # features (hybridization, aromaticity, donor/acceptor) are not used yet.
        for _ in range(n_atom):
            f.readline()
        f.readline()  # frequencies line
        smiles = f.readline().split()[0]
    mol_idx = prop_dict['index']

    m = Chem.MolFromSmiles(smiles)
    if m is None:
        raise ValueError('RDKit could not parse SMILES {!r} from {}'.format(smiles, graph_file))
    m = Chem.AddHs(m)
    assert n_atom == m.GetNumAtoms()
    n = m.GetNumAtoms()

    atom_type, atom_feat = [], []
    for i in range(n):
        atom = m.GetAtomWithIdx(i)
        atom_type.append(atom.GetAtomicNum())
        atom_feat.append((atom.GetAtomicNum(), atom.GetTotalNumHs(), atom.GetFormalCharge()))

    bond_type, bond_seg_i, bond_idx_j = [], [], []
    for i in range(n):
        for j in range(n):
            bond = m.GetBondBetweenAtoms(i, j)
            if bond is not None:
                bond_type.append(_qm9_bond_type_idx(bond.GetBondTypeAsDouble()))
                bond_seg_i.append(i)
                bond_idx_j.append(j)
        if self_loop:
            # Self edge for atom i, bond type 0.
            bond_type.append(0)
            bond_seg_i.append(i)
            bond_idx_j.append(i)

    mol_representation = {
        'n_atom': n_atom,
        'atom_type': atom_type,
        'atom_feat': atom_feat,
        'bond_type': bond_type,
        'bond_seg_i': bond_seg_i,
        'bond_idx_j': bond_idx_j,
    }
    return mol_idx, mol_representation, labels, smiles


def preprocess_qm9(dir_path='./data/qm9/dsgdb9nsd',
                   only_files=('dsgdb9nsd_047721.xyz', 'dsgdb9nsd_040834.xyz'),
                   out_prefix='viz_'):
    """Convert QM9 ``.xyz`` files in ``dir_path`` into JSON Lines files.

    Three files are written to ``dir_path``, each named with ``out_prefix``:
    ``drug.feat.self_loop.idx.jsonl`` (``<index>\\t<graph json>``),
    ``drug.labels.jsonl`` (``<index>\\t<labels json>``) and
    ``smiles.jsonl`` (``<file>\\t<index>\\t<smiles json>``).

    Args:
        dir_path: directory containing the ``.xyz`` files. A trailing path
            separator is optional.
        only_files: file names to process; ``None`` processes every ``.xyz``
            file. The default keeps the original two-molecule selection.
        out_prefix: prefix prepended to every output file name.
    """
    wanted = None if only_files is None else set(only_files)
    # Sorted so the output order does not depend on os.listdir ordering.
    xyz_files = sorted(
        name for name in os.listdir(dir_path)
        if name.endswith('.xyz')
        and os.path.isfile(os.path.join(dir_path, name))
        and (wanted is None or name in wanted))

    def out_path(name):
        return os.path.join(dir_path, out_prefix + name)

    with open(out_path('drug.feat.self_loop.idx.jsonl'), 'w') as f_graph, \
            open(out_path('drug.labels.jsonl'), 'w') as f_labels, \
            open(out_path('smiles.jsonl'), 'w') as f_smiles:
        for name in xyz_files:
            mol_idx, mol_representation, labels, smiles = _qm9_xyz_graph_reader(
                os.path.join(dir_path, name))
            f_graph.write('{}\t{}\n'.format(mol_idx, json.dumps(mol_representation)))
            f_labels.write('{}\t{}\n'.format(mol_idx, json.dumps(labels)))
            f_smiles.write('{}\t{}\t{}\n'.format(name, mol_idx, json.dumps(smiles)))
def main():
    """Command-line entry point: preprocess the requested datasets.

    DECAGON is read from ``<path>/decagon/`` and QM9 from
    ``<path>/qm9/dsgdb9nsd/``. ``prepare_data_dir(path)`` is called first.
    """
    parser = argparse.ArgumentParser(description='Preprocess datasets for Graph Co-attention')
    parser.add_argument('datasets', metavar='D', type=str.lower, nargs='+', choices=['qm9', 'decagon'],
                        help='Name of dataset to preprocess [QM9,DECAGON]')
    # I/O
    parser.add_argument('-p', '--path', metavar='dir', type=str, default='./data/',
                        help="path where the data is stored (default ./data/)")
    args = parser.parse_args()

    prepare_data_dir(args.path)

    # os.path.join tolerates a --path without a trailing separator. The final ''
    # keeps a trailing separator for callees that concatenate file names.
    if 'qm9' in args.datasets:
        preprocess_qm9(os.path.join(args.path, 'qm9', 'dsgdb9nsd', ''))
    if 'decagon' in args.datasets:
        preprocess_decagon(os.path.join(args.path, 'decagon', ''))


if __name__ == "__main__":
    main()