From 03f52e096be63afc602676648bf60662e7c8e34b Mon Sep 17 00:00:00 2001 From: Anant Gupta Date: Tue, 30 Oct 2018 14:38:27 -0400 Subject: [PATCH] BahdanauAttnDecoderRNN fix: syntax fixes, removed unused variable max_length handled dot product for batch-size=1 tested module --- seq2seq-translation/seq2seq-translation.ipynb | 188 +++++++++--------- 1 file changed, 96 insertions(+), 92 deletions(-) diff --git a/seq2seq-translation/seq2seq-translation.ipynb b/seq2seq-translation/seq2seq-translation.ipynb index a96cede..c22c30f 100644 --- a/seq2seq-translation/seq2seq-translation.ipynb +++ b/seq2seq-translation/seq2seq-translation.ipynb @@ -82,10 +82,8 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "collapsed": true - }, + "execution_count": 13, + "metadata": {}, "outputs": [], "source": [ "import unicodedata\n", @@ -111,10 +109,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "collapsed": true - }, + "execution_count": 14, + "metadata": {}, "outputs": [], "source": [ "USE_CUDA = True" @@ -198,9 +194,7 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "# Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427\n", @@ -266,9 +260,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "MAX_LENGTH = 10\n", @@ -302,9 +294,7 @@ { "cell_type": "code", "execution_count": 7, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -357,9 +347,7 @@ { "cell_type": "code", "execution_count": 8, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "# Return a list of indexes, one for each word in the sentence\n", @@ -400,10 +388,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false - }, + "execution_count": 15, + "metadata": {}, "outputs": [], "source": [ "class EncoderRNN(nn.Module):\n", @@ -504,29 +490,26 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": true - }, + "execution_count": 27, + "metadata": {}, "outputs": [], "source": [ "class BahdanauAttnDecoderRNN(nn.Module):\n", " def __init__(self, hidden_size, output_size, n_layers=1, dropout_p=0.1):\n", - " super(AttnDecoderRNN, self).__init__()\n", + " super(BahdanauAttnDecoderRNN, self).__init__()\n", " \n", " # Define parameters\n", " self.hidden_size = hidden_size\n", " self.output_size = output_size\n", " self.n_layers = n_layers\n", " self.dropout_p = dropout_p\n", - " self.max_length = max_length\n", " \n", " # Define layers\n", " self.embedding = nn.Embedding(output_size, hidden_size)\n", " self.dropout = nn.Dropout(dropout_p)\n", - " self.attn = GeneralAttn(hidden_size)\n", + " self.attn = Attn(\"general\", hidden_size)\n", " self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, dropout=dropout_p)\n", - " self.out = nn.Linear(hidden_size, output_size)\n", + " self.out = nn.Linear(hidden_size * 2, output_size)\n", " \n", " def forward(self, word_input, last_hidden, encoder_outputs):\n", " # Note that we will only be running forward for a single decoder time step, but will use all encoder outputs\n", @@ -545,10 +528,11 @@ " \n", " # Final output layer\n", " output = output.squeeze(0) # B x N\n", + " context = context.squeeze(0) # B x N \n", " output = F.log_softmax(self.out(torch.cat((output, context), 1)))\n", " \n", " # Return final output, hidden state, and attention weights (for visualization)\n", - " return output, hidden, attn_weights" + " return output, context, hidden, attn_weights" ] }, { @@ -586,14 +570,12 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": true - }, + "execution_count": 20, + "metadata": {}, "outputs": [], "source": [ "class Attn(nn.Module):\n", - " def __init__(self, method, hidden_size, max_length=MAX_LENGTH):\n", + " def __init__(self, method, hidden_size):\n", " super(Attn, self).__init__()\n", " \n", " self.method = method\n", @@ -621,14 +603,18 @@ " return F.softmax(attn_energies).unsqueeze(0).unsqueeze(0)\n", " \n", " def score(self, hidden, encoder_output):\n", + " # Using torch.bmm to perform dot product batch-wise\n", + " hidden = hidden.unsqueeze(1) # b x 1 x hidden_size\n", " \n", " if self.method == 'dot':\n", - " energy = hidden.dot(encoder_output)\n", + " encoder_output = encoder_output.unsqueeze(2) # b x hidden_size x 1 \n", + " energy = torch.bmm(hidden, encoder_output)\n", " return energy\n", " \n", " elif self.method == 'general':\n", " energy = self.attn(encoder_output)\n", - " energy = hidden.dot(energy)\n", + " energy = energy.unsqueeze(2) # b x hidden_size x 1 \n", + " energy = torch.bmm(hidden, energy)\n", " return energy\n", " \n", " elif self.method == 'concat':\n", @@ -646,10 +632,8 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false - }, + "execution_count": 21, + "metadata": {}, "outputs": [], "source": [ "class AttnDecoderRNN(nn.Module):\n", @@ -706,9 +690,8 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 30, "metadata": { - "collapsed": false, "scrolled": false }, "outputs": [ @@ -716,17 +699,34 @@ "name": "stdout", "output_type": "stream", "text": [ - "EncoderRNN (\n", + "EncoderRNN(\n", " (embedding): Embedding(10, 10)\n", " (gru): GRU(10, 10, num_layers=2)\n", ")\n", - "AttnDecoderRNN (\n", + "Luong Decoder\n", + "AttnDecoderRNN(\n", " (embedding): Embedding(10, 10)\n", " (gru): GRU(20, 10, num_layers=2, dropout=0.1)\n", - " (out): Linear (20 -> 10)\n", - " (attn): Attn (\n", - " (attn): Linear (10 -> 10)\n", + " (out): Linear(in_features=20, out_features=10, bias=True)\n", + " (attn): Attn(\n", + " (attn): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + ")\n", + "torch.Size([1, 10]) torch.Size([2, 1, 10]) torch.Size([1, 1, 3])\n", + "torch.Size([1, 10]) torch.Size([2, 1, 10]) torch.Size([1, 1, 3])\n", + "torch.Size([1, 10]) torch.Size([2, 1, 10]) torch.Size([1, 1, 3])\n", + "\n", + "\n", + "\n", + "Bahdanau Decoder\n", + "BahdanauAttnDecoderRNN(\n", + " (embedding): Embedding(10, 10)\n", + " (dropout): Dropout(p=0.1)\n", + " (attn): Attn(\n", + " (attn): Linear(in_features=10, out_features=10, bias=True)\n", " )\n", + " (gru): GRU(20, 10, num_layers=2, dropout=0.1)\n", + " (out): Linear(in_features=20, out_features=10, bias=True)\n", ")\n", "torch.Size([1, 10]) torch.Size([2, 1, 10]) torch.Size([1, 1, 3])\n", "torch.Size([1, 10]) torch.Size([2, 1, 10]) torch.Size([1, 1, 3])\n", @@ -738,6 +738,8 @@ "encoder_test = EncoderRNN(10, 10, 2)\n", "decoder_test = AttnDecoderRNN('general', 10, 10, 2)\n", "print(encoder_test)\n", + "\n", + "print(\"Luong Decoder\")\n", "print(decoder_test)\n", "\n", "encoder_hidden = encoder_test.init_hidden()\n", @@ -760,7 +762,35 @@ "for i in range(3):\n", " decoder_output, decoder_context, decoder_hidden, decoder_attn = decoder_test(word_inputs[i], decoder_context, decoder_hidden, encoder_outputs)\n", " print(decoder_output.size(), decoder_hidden.size(), decoder_attn.size())\n", - " decoder_attns[0, i] = decoder_attn.squeeze(0).cpu().data" + " decoder_attns[0, i] = decoder_attn.squeeze(0).cpu().data\n", + " \n", + "decoder_test = BahdanauAttnDecoderRNN(10, 10, 2) \n", + "print(\"\\n\\n\\nBahdanau Decoder\")\n", + "print(decoder_test)\n", + "\n", + "encoder_hidden = encoder_test.init_hidden()\n", + "word_input = Variable(torch.LongTensor([1, 2, 3]))\n", + "if USE_CUDA:\n", + " encoder_test.cuda()\n", + " word_input = word_input.cuda()\n", + "encoder_outputs, encoder_hidden = encoder_test(word_input, encoder_hidden)\n", + "\n", + "word_inputs = Variable(torch.LongTensor([1, 2, 3]))\n", + "decoder_attns = torch.zeros(1, 3, 3)\n", + "decoder_hidden = encoder_hidden\n", + "decoder_context = Variable(torch.zeros(1, decoder_test.hidden_size))\n", + "\n", + "if USE_CUDA:\n", + " decoder_test.cuda()\n", + " word_inputs = word_inputs.cuda()\n", + " decoder_context = decoder_context.cuda()\n", + "\n", + "for i in range(3):\n", + " decoder_output, decoder_context, decoder_hidden, decoder_attn = decoder_test(word_inputs[i], decoder_hidden, encoder_outputs)\n", + " print(decoder_output.size(), decoder_hidden.size(), decoder_attn.size())\n", + " decoder_attns[0, i] = decoder_attn.squeeze(0).cpu().data\n", + " \n", + " " ] }, { @@ -785,9 +815,7 @@ { "cell_type": "code", "execution_count": 14, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "teacher_forcing_ratio = 0.5\n", @@ -862,9 +890,7 @@ { "cell_type": "code", "execution_count": 15, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "def as_minutes(s):\n", @@ -894,9 +920,7 @@ { "cell_type": "code", "execution_count": 16, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "attn_model = 'general'\n", @@ -930,9 +954,7 @@ { "cell_type": "code", "execution_count": 17, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -968,7 +990,6 @@ "cell_type": "code", "execution_count": 18, "metadata": { - "collapsed": false, "scrolled": false }, "outputs": [ @@ -1121,9 +1142,7 @@ { "cell_type": "code", "execution_count": 19, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -1173,9 +1192,7 @@ { "cell_type": "code", "execution_count": 20, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "def evaluate(sentence, max_length=MAX_LENGTH):\n", @@ -1250,7 +1267,6 @@ "cell_type": "code", "execution_count": 22, "metadata": { - "collapsed": false, "scrolled": false }, "outputs": [ @@ -1283,9 +1299,7 @@ { "cell_type": "code", "execution_count": 24, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -1313,9 +1327,7 @@ { "cell_type": "code", "execution_count": 25, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [], "source": [ "def show_attention(input_sentence, output_words, attentions):\n", @@ -1346,9 +1358,7 @@ { "cell_type": "code", "execution_count": 26, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1376,9 +1386,7 @@ { "cell_type": "code", "execution_count": 27, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1406,9 +1414,7 @@ { "cell_type": "code", "execution_count": 28, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1436,9 +1442,7 @@ { "cell_type": "code", "execution_count": 29, - "metadata": { - "collapsed": false - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1497,7 +1501,7 @@ "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python [default]", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1511,7 +1515,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.2" + "version": "3.6.5" } }, "nbformat": 4,