"""
Model definition of DeepAttnMISL
If this work is useful for your research, please consider to cite our papers:
[1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks"
Jiawen Yao, XinliangZhu, Jitendra Jonnagaddala, NicholasHawkins, Junzhou Huang,
Medical Image Analysis, Available online 19 July 2020, 101789
[2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019
"""

import torch
import torch.nn as nn


class DeepAttnMIL_Surv(nn.Module):
    """
    Deep AttnMISL Model definition
    """

    def __init__(self, cluster_num):
        super(DeepAttnMIL_Surv, self).__init__()

        # Per-cluster embedding: a 1x1 conv maps 4096-d patch features to 64-d,
        # then adaptive average pooling aggregates over the spatial (patch) dims.
        self.embedding_net = nn.Sequential(
            nn.Conv2d(4096, 64, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Attention network that produces one score per cluster embedding.
        self.attention = nn.Sequential(
            nn.Linear(64, 32),  # V
            nn.Tanh(),
            nn.Linear(32, 1)    # W
        )

        # Prediction head on the attention-pooled representation.
        self.fc6 = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(32, 1)
        )

        self.cluster_num = cluster_num

    def masked_softmax(self, x, mask=None):
        """
        Performs a masked softmax along dim 1; simply masking after a plain
        softmax would be inaccurate because masked items still contribute to
        the normalizing sum.
        :param x: [batch_size, num_items]
        :param mask: [batch_size, num_items]
        :return: [batch_size, num_items] weights that sum to 1 over the
            unmasked items
        """
        if mask is not None:
            mask = mask.float()
            # Push masked positions to a large negative value so they do not
            # dominate the max used for numerical stabilization below.
            x_masked = x * mask + (1 - 1 / (mask + 1e-5))
        else:
            x_masked = x

        # Subtract the (masked) row-wise max for numerical stability.
        x_max = x_masked.max(1)[0]
        x_exp = (x - x_max.unsqueeze(-1)).exp()
        if mask is not None:
            x_exp = x_exp * mask
        return x_exp / x_exp.sum(1).unsqueeze(-1)
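
    # Example (verifiable against the method above): for x = [[1., 2., 3.]] and
    # mask = [[1., 1., 0.]], masked_softmax returns roughly
    # [[0.2689, 0.7311, 0.0]]: the masked position gets exactly zero weight
    # instead of a small spurious probability from a post-softmax mask.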

    def forward(self, x, mask):
        """x is a list of tensors, one feature tensor per patch cluster."""
        res = []
        for i in range(self.cluster_num):
            hh = x[i]
            output = self.embedding_net(hh)
            output = output.view(output.size()[0], -1)
            res.append(output)

        h = torch.cat(res)  # one 64-d embedding per cluster

        b = h.size(0)
        c = h.size(1)
        h = h.view(b, c)

        A = self.attention(h)
        A = torch.transpose(A, 1, 0)      # KxN
        A = self.masked_softmax(A, mask)  # attention weights over clusters

        M = torch.mm(A, h)  # KxL, attention-pooled slide representation

        Y_pred = self.fc6(M)  # survival risk prediction

        return Y_pred
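

# Minimal usage sketch (an illustrative assumption, not taken from the original
# repository's training code). It assumes each cluster is fed as a tensor of
# shape [1, 4096, n_i, 1] holding n_i patch features, and that `mask` is a
# [1, cluster_num] tensor with 1 for non-empty clusters; the actual dataloader
# may use a different layout.
if __name__ == "__main__":
    cluster_num = 4
    model = DeepAttnMIL_Surv(cluster_num)

    # Assumed shapes: one feature tensor per cluster, with varying patch counts.
    x = [torch.randn(1, 4096, n, 1) for n in (5, 3, 8, 2)]
    mask = torch.ones(1, cluster_num)

    risk = model(x, mask)  # -> tensor of shape [1, 1]
    print(risk.shape, risk.item())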