Network.py
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class classifier(nn.Module):
    # Define all the layers used in the model.
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers,
                 bidirectional, dropout):
        # Constructor
        super().__init__()

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM layer
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            batch_first=True)

        # Dense layer; hidden_dim * 2 assumes bidirectional=True, since the
        # final forward and backward hidden states are concatenated in forward().
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, text, text_lengths):
        # text = [batch size, sent_length]
        embedded = self.embedding(text)
        # embedded = [batch size, sent_len, emb_dim]

        # Pack the padded batch so the LSTM skips padding tokens;
        # pack_padded_sequence expects the lengths as a CPU tensor or list.
        packed_embedded = pack_padded_sequence(embedded, text_lengths,
                                               batch_first=True, enforce_sorted=False)
        packed_output, (hidden, cell) = self.lstm(packed_embedded)
        # hidden = [num layers * num directions, batch size, hid_dim]
        # cell = [num layers * num directions, batch size, hid_dim]

        # Concatenate the final forward and backward hidden states of the last layer.
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        # hidden = [batch size, hid_dim * num directions]

        dense_outputs = self.fc(hidden)
        return dense_outputs
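

# --- Minimal usage sketch (not part of the original file) ---
# Shows how the classifier above can be instantiated and called. The
# hyperparameter values and the toy batch below are illustrative
# assumptions, not values taken from this repository.
if __name__ == "__main__":
    vocab_size = 1000  # assumed vocabulary size

    model = classifier(vocab_size=vocab_size,
                       embedding_dim=100,
                       hidden_dim=64,
                       output_dim=1,
                       n_layers=2,
                       bidirectional=True,  # required by the hidden_dim * 2 dense layer
                       dropout=0.2)

    # Toy batch: 4 padded sequences of max length 10.
    text = torch.randint(1, vocab_size, (4, 10))
    text_lengths = torch.tensor([10, 8, 7, 5])  # true lengths, kept on the CPU

    logits = model(text, text_lengths)
    print(logits.shape)  # torch.Size([4, 1])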