# -*- coding: utf-8 -*-
"""
Created on Thu Mar 5 17:22:47 2020
@author: Meghana
"""
from logging import info as logging_info, debug as logging_debug
from pickle import load as pickle_load

import networkx as nx
from networkx import number_of_nodes as nx_number_of_nodes, write_weighted_edgelist as nx_write_weighted_edgelist, \
    Graph as nx_Graph
from numpy import argmax as np_argmax, argsort as np_argsort

from jaccard_coeff import jaccard_coeff
from convert_humap_ids2names import convert2names_wscores

def filter_overlapped(list_comp, inputs):
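    """
    Remove smaller complexes that overlap heavily with other predicted complexes.

    Complexes are sorted by size (ascending); a complex is dropped when its maximum
    Jaccard overlap with any remaining complex exceeds the threshold inputs['over_t'].
    """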
    logging_info("Filtering complexes...")

    # Sort complexes by size in ascending order, so smaller complexes are considered first
    sizes = [nx_number_of_nodes(comp) for comp in list_comp]
    sorted_ind = np_argsort(sizes)
    list_comp = [list_comp[i] for i in sorted_ind]

    fin_list = list(list_comp)
    list_comp2 = list(list_comp)

    for comp in list_comp:
        OS_comp = []
        list_comp2.remove(comp)
        if len(list_comp2):
            for comp2 in list_comp2:
                Overlap = jaccard_coeff(comp.nodes(), comp2.nodes())
                OS_comp.append(Overlap)
            OS_max = max(OS_comp)
            if OS_max > inputs['over_t']:
                fin_list.remove(comp)

    logging_info("Finished filtering complexes.")
    return fin_list

def NA(set1, set2):
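    """
    Overlap between two node sets.

    Returns the tuple (|A∩B|/|A|, |A∩B|/|B|, |A∩B|**2/(|A||B|)); the last term is the
    neighbourhood-affinity style score used by the '_qi' overlap criterion.
    """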
    ls1 = len(set1)
    ls2 = len(set2)
    if ls1 == 0 or ls2 == 0:
        # Return a consistent 3-tuple for degenerate inputs: full overlap only when both
        # sets are empty, no overlap when exactly one of them is empty
        return (1.0, 1.0, 1.0) if ls1 == ls2 else (0.0, 0.0, 0.0)
    inter = float(len(set1.intersection(set2)))
    a = inter / ls1
    b = inter / ls2
    return a, b, a * b

def merge_filter_overlapped_score_qi(list_comp, inputs, G, scores):
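    """
    Iteratively merge or prune overlapping complexes using a two-sided containment criterion.

    Two complexes overlap when both containment fractions from NA() are at least
    inputs['over_t']. Overlapping complexes are merged if the merged complex scores higher
    than both originals (scored via the density-interval -> score mapping in `scores`);
    otherwise the lower-scoring complex is removed. Repeats until a pass makes no changes.
    """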
    logging_info("Filtering complexes...")
    fin_list = list(list_comp)
    n = len(fin_list)
    if n <= 1:
        return fin_list

    value_functions = dict(scores)

    n_changes = 1
    while n_changes != 0:
        if len(fin_list) == 1:
            logging_debug("only one complex")
            break
        n_changes = 0
        ind = 0
        while ind < n:
            if len(fin_list) == 1:
                logging_debug("only one complex")
                break
            else:
                comp = fin_list[ind]
                temp_list = list(fin_list)
                temp_list.remove(comp)

                # Find the complex with the maximum overlap product with the current one
                OS_comp_3 = [NA(comp[0], comp2[0]) for comp2 in temp_list]
                OS_comp = [c for a, b, c in OS_comp_3]
                OS_max_ind = int(np_argmax(OS_comp))
                OS_max_3 = OS_comp_3[OS_max_ind]
                max_over_comp = temp_list[OS_max_ind]
                OS_max_ind_fin = fin_list.index(max_over_comp)

                if OS_max_3[0] >= inputs['over_t'] and OS_max_3[1] >= inputs['over_t']:
                    n_changes += 1
                    n -= 1
                    # Merge the two complexes and score the merged complex; keep the merged
                    # complex only if it scores higher than both individual complexes.
                    # Rather than a subgraph operation, which requires the full graph,
                    # one could add only the additional edges from the node adjacency lists.
                    merge_comp_nodes = comp[0].union(max_over_comp[0])
                    merge_comp = nx_Graph(G.subgraph(merge_comp_nodes), comp_score=0)

                    # Score the merged complex from its density via the interval lookup
                    dens_merge = nx.density(merge_comp)
                    intervals = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
                                 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
                    temp_dens = 0
                    for i in intervals:
                        if dens_merge <= i:
                            temp_dens = i
                            break
                    score_merge = value_functions[temp_dens]
                    merge_comp.graph['comp_score'] = score_merge

                    sc1 = comp[1]
                    sc2 = max_over_comp[1]
                    if score_merge > sc1 and score_merge > sc2:
                        fin_list.append((frozenset(merge_comp.nodes()), merge_comp.graph['comp_score']))
                        fin_list.remove(comp)
                        fin_list.remove(max_over_comp)
                        if OS_max_ind_fin <= ind:
                            ind -= 1
                    # Otherwise remove the lower-scoring complex
                    elif sc1 <= sc2:
                        fin_list.remove(comp)
                    else:
                        fin_list.remove(max_over_comp)
                        if OS_max_ind_fin > ind:
                            ind += 1
                else:
                    ind += 1
        logging_info("No. of changes = %s", str(n_changes))

    logging_info("Finished filtering complexes.")
    return fin_list

def merge_filter_overlapped_score(list_comp, inputs, G):
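    """
    Iteratively merge or prune overlapping complexes using Jaccard overlap.

    Two complexes with Jaccard overlap >= inputs['over_t'] are merged if the merged complex
    scores higher than both originals (scored via the pickled density-interval -> score
    mapping); otherwise the lower-scoring complex is removed. Repeats until a pass makes
    no changes.
    """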
    logging_info("Filtering complexes...")
    fin_list = list(list_comp)
    n = len(fin_list)
    if n <= 1:
        return fin_list

    # Density-interval -> score mapping learnt during training, loaded once here rather
    # than inside the merge loop.
    # NOTE: `args` (with a `pred_results` attribute) is assumed to be provided by the
    # driver script that imports this module; it is not defined in this file.
    with open(args.pred_results + '/value_fns_interp.txt', 'rb') as f:
        value_functions = dict(pickle_load(f))

    n_changes = 1
    while n_changes != 0:
        if len(fin_list) == 1:
            logging_debug("only one complex")
            break
        n_changes = 0
        ind = 0
        while ind < n:
            if len(fin_list) == 1:
                logging_debug("only one complex")
                break
            else:
                comp = fin_list[ind]
                temp_list = list(fin_list)
                temp_list.remove(comp)

                # Find the complex with the maximum Jaccard overlap with the current one
                OS_comp = [jaccard_coeff(comp[0], comp2[0]) for comp2 in temp_list]
                OS_max_ind = int(np_argmax(OS_comp))
                OS_max = OS_comp[OS_max_ind]
                max_over_comp = temp_list[OS_max_ind]
                OS_max_ind_fin = fin_list.index(max_over_comp)

                if OS_max >= inputs['over_t']:
                    n_changes += 1
                    n -= 1
                    # Merge the two complexes and score the merged complex; keep the merged
                    # complex only if it scores higher than both individual complexes.
                    # Rather than a subgraph operation, which requires the full graph,
                    # one could add only the additional edges from the node adjacency lists.
                    merge_comp_nodes = comp[0].union(max_over_comp[0])
                    merge_comp = nx_Graph(G.subgraph(merge_comp_nodes), comp_score=0)

                    # Score the merged complex from its density via the interval lookup
                    dens_merge = nx.density(merge_comp)
                    intervals = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
                                 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
                    temp_dens = 0
                    for i in intervals:
                        if dens_merge <= i:
                            temp_dens = i
                            break
                    score_merge = value_functions[temp_dens]
                    merge_comp.graph['comp_score'] = score_merge

                    sc1 = comp[1]
                    sc2 = max_over_comp[1]
                    if score_merge > sc1 and score_merge > sc2:
                        fin_list.append((frozenset(merge_comp.nodes()), merge_comp.graph['comp_score']))
                        fin_list.remove(comp)
                        fin_list.remove(max_over_comp)
                        if OS_max_ind_fin <= ind:
                            ind -= 1
                    # Otherwise remove the lower-scoring complex
                    elif sc1 <= sc2:
                        fin_list.remove(comp)
                    else:
                        fin_list.remove(max_over_comp)
                        if OS_max_ind_fin > ind:
                            ind += 1
                else:
                    ind += 1
        logging_info("No. of changes = %s", str(n_changes))

    logging_info("Finished filtering complexes.")
    return fin_list

def postprocess(pred_comp_list, modelfname, scaler, inputs, G, prot_list, train_prot_list, test_prot_list):
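    """
    Post-process predicted complexes: drop complexes with fewer than three nodes, keep unique
    complexes, merge/filter highly overlapping ones, sort by score, and write the predictions
    and per-edge maximum complex scores to files under inputs['dir_nm'] + inputs['out_comp_nm'].
    """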
    with open(modelfname, 'rb') as f:
        model = pickle_load(f)  # trained model; not used further in this function

    if len(pred_comp_list) == 0:
        return pred_comp_list

    # Remove complexes with fewer than three nodes and keep only unique (complex, score) pairs
    fin_list_graphs = set([(comp, score) for comp, score in pred_comp_list if len(comp) > 2])
    logging_info("Finished sampling complexes.")

    # Filter complexes with high overlap with bigger complexes
    if inputs['overlap_method'] == 'testing_qi_0.3':
        # The '_qi' variant needs the density-interval -> score mapping explicitly; assumed
        # here to be the same pickled mapping used by merge_filter_overlapped_score.
        with open(args.pred_results + '/value_fns_interp.txt', 'rb') as f:
            value_fns = pickle_load(f)
        fin_list_graphs = merge_filter_overlapped_score_qi(fin_list_graphs, inputs, G, value_fns)
    else:
        fin_list_graphs = merge_filter_overlapped_score(fin_list_graphs, inputs, G)
    # Sort complexes by score in descending order
    fin_list_graphs = sorted(fin_list_graphs, key=lambda x: x[1], reverse=True)

    logging_info("Writing Predicted complexes.")
    out_comp_nm = inputs['dir_nm'] + inputs['out_comp_nm']
    if inputs['dir_nm'] == "humap":
        convert2names_wscores(fin_list_graphs, out_comp_nm + '_pred_names.out',
                              G, out_comp_nm + '_pred_edges_names.out')

    # For every predicted edge, keep the maximum score over all complexes containing it
    tot_pred_edges_unique_max_comp_prob = {}

    with open(out_comp_nm + '_pred.out', "w") as fn:
        with open(out_comp_nm + '_pred_edges.out', "wb") as f_edges:
            fn_write = fn.write
            f_edges_write = f_edges.write
            for index in range(len(fin_list_graphs)):
                tmp_graph_nodes = fin_list_graphs[index][0]
                tmp_score = fin_list_graphs[index][1]

                # One complex per line: its nodes followed by its score
                for node in tmp_graph_nodes:
                    fn_write("%s " % node)
                fn_write("%.3f" % tmp_score)

                # Weighted edge list of the complex subgraph
                tmp_graph = G.subgraph(tmp_graph_nodes)
                nx_write_weighted_edgelist(tmp_graph, f_edges)
                tmp_graph_edges = tmp_graph.edges()
                for edge in tmp_graph_edges:
                    edge_set = frozenset([edge[0], edge[1]])
                    if edge_set in tot_pred_edges_unique_max_comp_prob:
                        tot_pred_edges_unique_max_comp_prob[edge_set] = max(tot_pred_edges_unique_max_comp_prob[edge_set], tmp_score)
                    else:
                        tot_pred_edges_unique_max_comp_prob[edge_set] = tmp_score

                fn_write("\n")
                f_edges_write("\n".encode())

    # Write per-edge maximum complex scores: all edges, then edges restricted to the known,
    # training, and test protein lists
    with open(out_comp_nm + '_tot_pred_edges_unique_max_comp_prob.out', "w") as f:
        with open(out_comp_nm + '_tot_pred_edges_unique_max_comp_prob_inKnown.out', "w") as f_inKnown:
            with open(out_comp_nm + '_tot_pred_edges_unique_max_comp_prob_inKnown_train.out', "w") as f_inKnown_train:
                with open(out_comp_nm + '_tot_pred_edges_unique_max_comp_prob_inKnown_test.out', "w") as f_inKnown_test:
                    for edge_key in tot_pred_edges_unique_max_comp_prob:
                        edge = list(edge_key)
                        edge_score = tot_pred_edges_unique_max_comp_prob[edge_key]
                        f.write(edge[0] + "\t" + edge[1] + "\t" + "%.3f" % edge_score + "\n")
                        if edge[0] in prot_list and edge[1] in prot_list:
                            f_inKnown.write(edge[0] + "\t" + edge[1] + "\t" + "%.3f" % edge_score + "\n")
                        if edge[0] in train_prot_list and edge[1] in train_prot_list:
                            f_inKnown_train.write(edge[0] + "\t" + edge[1] + "\t" + "%.3f" % edge_score + "\n")
                        if edge[0] in test_prot_list and edge[1] in test_prot_list:
                            f_inKnown_test.write(edge[0] + "\t" + edge[1] + "\t" + "%.3f" % edge_score + "\n")

    logging_info("Finished writing Predicted complexes.")
    return fin_list_graphs