From 8f8677ad0e86d486c8359199513ed3889d74ebac Mon Sep 17 00:00:00 2001 From: nkoumchatzky Date: Tue, 30 May 2017 18:48:06 -0400 Subject: [PATCH] C implementation of LSTM -- REVIEWABLE * New "sparse/size" representation * Full LSTM in C * VSeqLSTM to wrap this data representation + C implementation * Augmentation of the VariableLength decorator with this data representation from an array of tensors * unit tests * speed tests --- CMakeLists.txt | 11 + LookupTableMaskZero.lua | 6 +- Module.lua | 9 + SeqLSTM.lua | 2 +- THRNN.lua | 147 ++++++ VSeqLSTM.lua | 261 +++++++++++ VariableLength.lua | 285 ++++++++++-- init.lua | 4 +- lib/CMakeLists.txt | 5 + lib/THRNN/CMakeLists.txt | 83 ++++ lib/THRNN/THRNN.h | 27 ++ lib/THRNN/generic/LSTM.c | 944 ++++++++++++++++++++++++++++++++++++++ lib/THRNN/generic/THRNN.h | 64 +++ lib/THRNN/init.c | 4 + test/bigtest.lua | 152 +++++- test/test.lua | 145 +++++- 16 files changed, 2095 insertions(+), 54 deletions(-) create mode 100644 THRNN.lua create mode 100644 VSeqLSTM.lua create mode 100644 lib/CMakeLists.txt create mode 100644 lib/THRNN/CMakeLists.txt create mode 100644 lib/THRNN/THRNN.h create mode 100644 lib/THRNN/generic/LSTM.c create mode 100644 lib/THRNN/generic/THRNN.h create mode 100644 lib/THRNN/init.c diff --git a/CMakeLists.txt b/CMakeLists.txt index 74efbc0..de87ccd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,12 +5,20 @@ CMAKE_POLICY(VERSION 2.6) FIND_PACKAGE(Torch REQUIRED) +ADD_SUBDIRECTORY(lib) + SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}") SET(src init.c ) + +FILE(STRINGS lib/THRNN/generic/THRNN.h THRNN_headers NEWLINE_CONSUME) +FILE(WRITE THRNN_h.lua "return [[") +FILE(APPEND THRNN_h.lua ${THRNN_headers}) +FILE(APPEND THRNN_h.lua "]]") + SET(luasrc init.lua AbstractRecurrent.lua @@ -36,6 +44,7 @@ SET(luasrc SeqBLSTM.lua SeqGRU.lua SeqLSTM.lua + VSeqLSTM.lua Sequencer.lua SequencerCriterion.lua test/bigtest.lua @@ -75,6 +84,8 @@ SET(luasrc deprecated/FastLSTM.lua deprecated/GRU.lua deprecated/LSTM.lua + THRNN.lua + THRNN_h.lua ) ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch") diff --git a/LookupTableMaskZero.lua b/LookupTableMaskZero.lua index 9721175..63f0efd 100644 --- a/LookupTableMaskZero.lua +++ b/LookupTableMaskZero.lua @@ -5,17 +5,17 @@ function LookupTableMaskZero:__init(nIndex, nOutput) end function LookupTableMaskZero:updateOutput(input) - self.weight[1]:zero() + self.weight[1]:zero() if self.__input and (torch.type(self.__input) ~= torch.type(input)) then self.__input = nil -- fixes old casting bug end self.__input = self.__input or input.new() self.__input:resizeAs(input):add(input, 1) - return parent.updateOutput(self, self.__input) + return parent.updateOutput(self, self.__input) end function LookupTableMaskZero:accGradParameters(input, gradOutput, scale) - parent.accGradParameters(self, self.__input, gradOutput, scale) + parent.accGradParameters(self, self.__input, gradOutput, scale) end function LookupTableMaskZero:type(type, cache) diff --git a/Module.lua b/Module.lua index c613959..0fc7c3a 100644 --- a/Module.lua +++ b/Module.lua @@ -38,6 +38,15 @@ function Module:setZeroMask(zeroMask) end end +function Module:setContext(context) + if self.modules then + for i, module in ipairs(self.modules) do + module:setContext(context) + end + end + self.__context = context +end + function Module:stepClone(shareParams, shareGradParams) return 
self:sharedClone(shareParams, shareGradParams, true) end diff --git a/SeqLSTM.lua b/SeqLSTM.lua index 1ba27ac..d622009 100644 --- a/SeqLSTM.lua +++ b/SeqLSTM.lua @@ -457,4 +457,4 @@ function SeqLSTM:toRecLSTM() end return lstm -end \ No newline at end of file +end diff --git a/THRNN.lua b/THRNN.lua new file mode 100644 index 0000000..cc97979 --- /dev/null +++ b/THRNN.lua @@ -0,0 +1,147 @@ +local ffi = require 'ffi' + +local THRNN = {} + + +local generic_THRNN_h = require 'rnn.THRNN_h' +-- strip all lines starting with # +-- to remove preprocessor directives originally present +-- in THRNN.h +generic_THRNN_h = generic_THRNN_h:gsub("\n#[^\n]*", "") +generic_THRNN_h = generic_THRNN_h:gsub("^#[^\n]*\n", "") + +-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h +local base_declarations = [[ +typedef void THRNNState; + +typedef struct { + unsigned long the_initial_seed; + int left; + int seeded; + unsigned long next; + unsigned long state[624]; /* the array for the state vector 624 = _MERSENNE_STATE_N */ + double normal_x; + double normal_y; + double normal_rho; + int normal_is_valid; +} THGenerator; +]] + +-- polyfill for LUA 5.1 +if not package.searchpath then + local sep = package.config:sub(1,1) + function package.searchpath(mod, path) + mod = mod:gsub('%.', sep) + for m in path:gmatch('[^;]+') do + local nm = m:gsub('?', mod) + local f = io.open(nm, 'r') + if f then + f:close() + return nm + end + end + end +end + +-- load libTHRNN +THRNN.C = ffi.load(package.searchpath('libTHRNN', package.cpath)) + +ffi.cdef(base_declarations) + +-- expand macros, allow to use original lines from lib/THRNN/generic/THRNN.h +local preprocessed = string.gsub(generic_THRNN_h, 'TH_API void THRNN_%(([%a%d_]+)%)', 'void THRNN_TYPE%1') + +local replacements = +{ + { + ['TYPE'] = 'Double', + ['real'] = 'double', + ['THTensor'] = 'THDoubleTensor', + ['THIndexTensor'] = 'THLongTensor', + ['THIntegerTensor'] = 'THIntTensor', + ['THIndex_t'] = 'long', + ['THInteger_t'] = 'int' + }, + { + ['TYPE'] = 'Float', + ['real'] = 'float', + ['THTensor'] = 'THFloatTensor', + ['THIndexTensor'] = 'THLongTensor', + ['THIntegerTensor'] = 'THIntTensor', + ['THIndex_t'] = 'long', + ['THInteger_t'] = 'int' + } +} + +-- gsub(s, 'real', 'float') changes accreal to accfloat. +-- typedef accfloat ahead of time. +ffi.cdef("typedef double accfloat;") +-- gsub(s, 'real', 'double') changes accreal to accfloat. +-- typedef accdouble ahead of time +ffi.cdef("typedef double accdouble;") + +for i=1,#replacements do + local r = replacements[i] + local s = preprocessed + for k,v in pairs(r) do + s = string.gsub(s, k, v) + end + ffi.cdef(s) +end + +THRNN.NULL = ffi.NULL or nil + +function THRNN.getState() + return ffi.NULL or nil +end + +function THRNN.optionalTensor(t) + return t and t:cdata() or THRNN.NULL +end + +local function extract_function_names(s) + local t = {} + for n in string.gmatch(s, 'TH_API void THRNN_%(([%a%d_]+)%)') do + t[#t+1] = n + end + return t +end + +function THRNN.bind(lib, base_names, type_name, state_getter) + local ftable = {} + local prefix = 'THRNN_' .. type_name + for i,n in ipairs(base_names) do + -- use pcall since some libs might not support all functions (e.g. cunn) + local ok,v = pcall(function() return lib[prefix .. n] end) + if ok then + ftable[n] = function(...) v(state_getter(), ...) end -- implicitely add state + else + print('not found: ' .. prefix .. n .. 
v) + end + end + return ftable +end + +-- build function table +local function_names = extract_function_names(generic_THRNN_h) + +THRNN.kernels = {} +THRNN.kernels['torch.FloatTensor'] = THRNN.bind(THRNN.C, function_names, 'Float', THRNN.getState) +THRNN.kernels['torch.DoubleTensor'] = THRNN.bind(THRNN.C, function_names, 'Double', THRNN.getState) + +torch.getmetatable('torch.FloatTensor').THRNN = THRNN.kernels['torch.FloatTensor'] +torch.getmetatable('torch.DoubleTensor').THRNN = THRNN.kernels['torch.DoubleTensor'] + +function THRNN.runKernel(f, type, ...) + local ftable = THRNN.kernels[type] + if not ftable then + error('Unsupported tensor type: '..type) + end + local f = ftable[f] + if not f then + error(string.format("Function '%s' not found for tensor type '%s'.", f, type)) + end + f(...) +end + +return THRNN diff --git a/VSeqLSTM.lua b/VSeqLSTM.lua new file mode 100644 index 0000000..6a54e89 --- /dev/null +++ b/VSeqLSTM.lua @@ -0,0 +1,261 @@ +local VSeqLSTM, parent = torch.class('nn.VSeqLSTM', 'nn.Module') + +function VSeqLSTM:__init(inputsize, outputsize) + parent.__init(self) + + self.inputsize, self.outputsize, self.hiddensize = inputsize, outputsize, outputsize + + self.weight = torch.Tensor(inputsize+outputsize, 4 * outputsize) + self.gradWeight = torch.Tensor(inputsize+outputsize, 4 * outputsize) + + self.bias = torch.Tensor(4 * outputsize) + self.gradBias = torch.Tensor(4 * outputsize):zero() + self:reset() + + self.cell = torch.Tensor() + self.buffer = torch.Tensor() + self.gradInputBuffer = torch.Tensor() + self.weightBuffer = torch.Tensor() + + self.h0 = torch.Tensor() + self.c0 = torch.Tensor() + + self._remember = 'neither' + + -- set this to true for variable length sequences that seperate + -- independent sequences with a step of zeros (a tensor of size D) + self.maskzero = false + self.v2 = true +end + +VSeqLSTM.reset = nn.StepLSTM.reset + +--[[ +Input: +- c0: Initial cell state, (N, H) +- h0: Initial hidden state, (N, H) +- x: Input sequence, ((T1+T2+...+Tn, D) + +Output: +- h: Sequence of hidden states, (T, N, H) +--]] + +function VSeqLSTM:updateOutput(input) + self.recompute_backward = true + -- Expect a { torch.Tensor, torch.LongTensor } + -- where the first tensor is the concatenated array of input values, + -- following the 'timesteps-first' format + -- ans the second is the array of decreasing batch sizes + local linput, sizes + if torch.isTensor(input) then + assert(self.__context and self.__context.sizes) + linput = input + sizes = self.__context.sizes + elseif type(input) == 'table' and input[1] and input[2] then + linput = input[1] + sizes = input[2] + else + error('Cannot recognize input') + end + local batchsize = sizes[1] + local inputsize, outputsize = self.inputsize, self.outputsize + + + -- remember previous state? 
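+  -- (if so, c0/h0 are re-seeded below from the final cell/hidden state of the
+  -- previous forward; this requires a constant batch size across calls)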
+ local remember = self:hasMemory() + + local c0 = self.c0 + if (c0:nElement() ~= batchsize * outputsize) or not remember then + c0:resize(batchsize, outputsize):zero() + elseif remember then + assert(self.cell:size(2) == batchsize, 'batch sizes must be constant to remember states') + c0:copy(self.cell[self.cell:size(1)]) + end + + local h0 = self.h0 + if (h0:nElement() ~= batchsize * outputsize) or not remember then + h0:resize(batchsize, outputsize):zero() + elseif remember then + assert(self.output:size(2) == batchsize, 'batch sizes must be the same to remember states') + h0:copy(self.output[self.output:size(1)]) + end + + local h, c = self.output, self.cell + + linput.THRNN.LSTM_updateOutput( + linput:cdata(), + c0:cdata(), + h0:cdata(), + sizes:cdata(), + self.output:cdata(), + self.weight:cdata(), + self.bias:cdata(), + self.buffer:cdata(), + self.inputsize, + self.outputsize, + 1, + 0, + 0) + + + return self.output +end + +function VSeqLSTM:backward(input, gradOutput, scale) + self.recompute_backward = false + scale = scale or 1.0 + assert(scale == 1.0, 'must have scale=1') + + local linput, sizes + if torch.isTensor(input) then + assert(self.__context and self.__context.sizes) + linput = input + sizes = self.__context.sizes + elseif type(input) == 'table' and input[1] and input[2] then + linput = input[1] + sizes = input[2] + else + error('Cannot recognize input') + end + local batchsize = sizes[1] + local inputsize, outputsize = self.inputsize, self.outputsize + + if not self.grad_hT then + self.gradInputH = self.gradInputH or self.h0.new() + self.gradInputH:resize(self.h0:size()):zero() + else + self.gradInputH = self.grad_hT + end + + if not self.grad_cT then + self.gradInputC = self.gradInputC or self.c0.new() + self.gradInputC:resize(self.c0:size()):zero() + else + self.gradInputC = self.grad_cT + end + + + linput.THRNN.LSTM_backward( + linput:cdata(), + self.c0:cdata(), + self.h0:cdata(), + sizes:cdata(), + gradOutput:cdata(), + self.gradInput:cdata(), + self.gradInputH:cdata(), + self.gradInputC:cdata(), + self.weight:cdata(), + self.bias:cdata(), + self.buffer:cdata(), + self.weightBuffer:cdata(), + self.gradInputBuffer:cdata(), + self.gradWeight:cdata(), + self.gradBias:cdata(), + self.output:cdata(), + scale, + 0, + inputsize, + outputsize, + 1, + 0) + + return self.gradInput +end + +function VSeqLSTM:clearState() + self.cell:set() + self.buffer:set() + self.weightBuffer:set() + self.gradInputBuffer:set() + self.c0:set() + self.h0:set() + + self.output:set() + self.gradInput:set() + self.grad_hidden = nil + self.hidden = nil + + self.zeroMask = nil +end + +function VSeqLSTM:updateGradInput(input, gradOutput) + if self.recompute_backward then + self:backward(input, gradOutput, 1.0) + end + return self.gradInput +end + +function VSeqLSTM:accGradParameters(input, gradOutput, scale) + if self.recompute_backward then + self:backward(input, gradOutput, scale) + end +end + +function VSeqLSTM:forget() + self.c0:resize(0) + self.h0:resize(0) +end + +function VSeqLSTM:type(type, ...) + self:clearState() + return parent.type(self, type, ...) +end + +-- Toggle to feed long sequences using multiple forwards. 
+-- 'eval' only affects evaluation (recommended for RNNs) +-- 'train' only affects training +-- 'neither' affects neither training nor evaluation +-- 'both' affects both training and evaluation (recommended for LSTMs) +VSeqLSTM.remember = nn.AbstractSequencer.remember +VSeqLSTM.hasMemory = nn.AbstractSequencer.hasMemory + +function VSeqLSTM:training() + if self.train == false then + -- forget at the start of each training + self:forget() + end + parent.training(self) +end + +function VSeqLSTM:evaluate() + if self.train ~= false then + -- forget at the start of each evaluation + self:forget() + end + parent.evaluate(self) + assert(self.train == false) +end + +VSeqLSTM.maskZero = nn.StepLSTM.maskZero +VSeqLSTM.setZeroMask = nn.MaskZero.setZeroMask +VSeqLSTM.__tostring__ = nn.StepLSTM.__tostring__ + +function VSeqLSTM:parameters() + return {self.weight, self.bias}, {self.gradWeight, self.gradBias} +end + +function VSeqLSTM:setStartState(hiddenState) + self.h0:resizeAs(hiddenState[1]):copy(hiddenState[1]) + self.c0:resizeAs(hiddenState[2]):copy(hiddenState[2]) +end + +function VSeqLSTM:setHiddenState(step, hiddenState) + if step == 0 then + self:setStartState(hiddenState) + else + error"NotImplemented" + end +end + +function VSeqLSTM:getHiddenState() + error"NotImplemented" +end + +function VSeqLSTM:setGradHiddenState() + error"NotImplemented" +end + +function VSeqLSTM:getGradHiddenState() + error"NotImplemented" +end + diff --git a/VariableLength.lua b/VariableLength.lua index 0261b9b..f8672c7 100644 --- a/VariableLength.lua +++ b/VariableLength.lua @@ -1,9 +1,46 @@ local VariableLength, parent = torch.class("nn.VariableLength", "nn.Decorator") -function VariableLength:__init(module, lastOnly) - parent.__init(self, assert(module:maskZero())) +-- This modules deals with variable lengths representations +-- of input data. 
It takes the simplest representation of variable length sequences: +-- +-- { +-- torch.Tensor(T1, [...]), +-- torch.Tensor(T2, [...]), +-- ..., +-- torch.Tensor(TN, [...]) +-- } +-- +-- and turns it either into an equivalent (in terms of amount of info) +-- "sparse/size" representation: +-- +-- "sparse": torch.Tensor(T1+...+TN, [...]) +-- "size": torch.LongTensor({T1,T2,...,TN}) +-- +-- where "sparse" is the direct concatenation of the input array of tensors +-- and "size" is the 1D tensor containing the sequence lengths +-- +-- or into an equivalent (in terms of amount of info) +-- "dense/masked" representation: +-- +-- "dense": torch.Tensor(max_i{T}, N, [...]) +-- "mask": torch.ByteTensor(max_i{T}, N) +-- where max_i{T} is the maximum sequence length, +-- "dense" is the rectangular version of the input data, +-- "mask" indicates where a sequence ends (1) or is valid (0) +function VariableLength:__init(module, lastOnly, sparse) + parent.__init(self, module) -- only extract the last element of each sequence self.lastOnly = lastOnly -- defaults to false + if sparse or torch.type(module) == 'nn.VSeqLSTM' then + self.sparse = true + self.sizes = torch.LongTensor() + self.output = torch.Tensor() + self.mapping = torch.LongTensor() + self.sorted_by_batch_sizes = torch.LongTensor() + self.sorted_by_time_sizes = torch.LongTensor() + self.sorted_by_time_indices = torch.LongTensor() + self.cinput_bis = torch.Tensor() + end self.gradInput = {} end @@ -13,53 +50,231 @@ function VariableLength:updateOutput(input) assert(torch.isTensor(input[1])) local batchSize = #input - self._input = self._input or input[1].new() - -- mask is a binary tensor with 1 where self._input is zero (between sequence zero-mask) - self._mask = self._mask or torch.ByteTensor() + if self.sparse then + -- Path for "sparse/size" representation of arrays of tensors, + -- where an array of tensors is transformed into its + -- sparse/size equivalent representation, i.e.: + -- { + -- torch.Tensor(T1, ...), + -- torch.Tensor(T2, ...), + -- ..., + -- torch.Tensor(TN, ...) + -- } + -- --> + -- { torch.Tensor(T1+...+TN, ...), torch.LongTensor({T1,T2,...,TN) } + + + -- Initialize a bunch of buffers + local first_input = input[1] + self.cinput = self.cinput or first_input.new() + self.cinput = self.cinput:type(first_input:type()) + self.cgradInput = self.cgradInput or first_input.new() + self.cgradInput = self.cgradInput:type(first_input:type()) + self._input = self._input or first_input.new() + self._input = self._input:type(first_input:type()) + self._input = self._input or {} + + -- Concatenate the array of tensors, + -- extract the sequence sizes + local sm = 0 + local mx = 0 + self.cinput:cat(input, 1) + self.sizes:resize(#input) + for i=1,#input do + self.sizes[i] = input[i]:size(1) + end + + -- From the concatenated, batch-first 'self.cinput', along with 'self.sizes', + -- transpose to a time-first, sorted in decreasing order of batch size + -- to 'self._input' + self.cinput.THRNN.LSTM_bt_to_sorted_tb( + self.cinput:cdata(), + self.sizes:cdata(), + self._input:cdata(), + self.mapping:cdata(), + self.sorted_by_batch_sizes:cdata(), + self.sorted_by_time_sizes:cdata(), + self.sorted_by_time_indices:cdata(), + 0) - -- now we process input into _input. - -- indexes and mappedLengths are meta-information tables, explained below. 
- self.indexes, self.mappedLengths = self._input.nn.VariableLength_FromSamples(input, self._input, self._mask) + -- Set the context for all the modules, + -- containing the sorted sizes + self.context = self.context or {} + self.context.sizes = self.sorted_by_batch_sizes + self.modules[1]:setContext(self.context) - -- zero-mask the _input where mask is 1 - nn.utils.recursiveZeroMask(self._input, self._mask) - self.modules[1]:setZeroMask(self._mask) + -- Run the wrapped module + local output = self.modules[1]:updateOutput(self._input) - -- feedforward the zero-mask format through the decorated module - local output = self.modules[1]:updateOutput(self._input) + if self.lastOnly then + -- Extract the last time step of each sample. + -- self.output tensor has shape: batchSize [x outputSize] + self.output = torch.isTensor(self.output) and self.output or output.new() + self.cinput.THRNN.LSTM_sorted_tb_to_bt( + output:cdata(), + self.sizes:cdata(), + self.mapping:cdata(), + self.output:cdata(), + 1) + else + -- Reverse the transpose + self.output = {} + self._output = self._output or first_input.new() + self._output = self._output:type(first_input:type()) + output.THRNN.LSTM_sorted_tb_to_bt( + output:cdata(), + self.sizes:cdata(), + self.mapping:cdata(), + self._output:cdata(), + 0) + local runningIdx = 1 + for i=1,#input do + self.output[i] = self._output:narrow(1, runningIdx, self.sizes[i]) + runningIdx = runningIdx + self.sizes[i] + end + end - if self.lastOnly then - -- Extract the last time step of each sample. - -- self.output tensor has shape: batchSize [x outputSize] - self.output = torch.isTensor(self.output) and self.output or output.new() - self.output.nn.VariableLength_ToFinal(self.indexes, self.mappedLengths, output, self.output) else - -- This is the revese operation of everything before updateOutput - self.output = self._input.nn.VariableLength_ToSamples(self.indexes, self.mappedLengths, output) - end + -- Path for "dense/masked" representations of arrays of tensors, + -- where an array of tensors is transformed into its + -- dense/masked equivalent representation, i.e.: + -- { + -- torch.Tensor(T1, ...), + -- torch.Tensor(T2, ...), + -- ..., + -- torch.Tensor(TN, ...) + -- } + -- --> + -- { torch.Tensor(max_i{T}, N, ...), torch.ByteTensor(max_i{T}, N) } + -- where max_i{T} is the maximum sequence length + + self._input = self._input or input[1].new() + -- mask is a binary tensor with 1 where self._input is zero (between sequence zero-mask) + self._mask = self._mask or torch.ByteTensor() + + -- now we process input into _input. + -- indexes and mappedLengths are meta-information tables, explained below. + self.indexes, self.mappedLengths = self._input.nn.VariableLength_FromSamples(input, self._input, self._mask) + + -- zero-mask the _input where mask is 1 + nn.utils.recursiveZeroMask(self._input, self._mask) + self.modules[1]:setZeroMask(self._mask) + + -- feedforward the zero-mask format through the decorated module + local output = self.modules[1]:updateOutput(self._input) + if self.lastOnly then + -- Extract the last time step of each sample. 
+ -- self.output tensor has shape: batchSize [x outputSize] + self.output = torch.isTensor(self.output) and self.output or output.new() + self.output.nn.VariableLength_ToFinal(self.indexes, self.mappedLengths, output, self.output) + else + -- This is the revese operation of everything before updateOutput + self.output = self._input.nn.VariableLength_ToSamples(self.indexes, self.mappedLengths, output) + end + end return self.output end function VariableLength:updateGradInput(input, gradOutput) - self._gradOutput = self._gradOutput or self._input.new() - if self.lastOnly then - assert(torch.isTensor(gradOutput)) - self._gradOutput.nn.VariableLength_FromFinal(self.indexes, self.mappedLengths, gradOutput, self._gradOutput) + assert(torch.type(input) == 'table') + assert(torch.isTensor(input[1])) + if self.sparse then + -- Path for "sparse/size" representation of arrays of tensors, + -- where an array of tensors is transformed into its + -- sparse/size equivalent representation, i.e.: + -- { + -- torch.Tensor(T1, ...), + -- torch.Tensor(T2, ...), + -- ..., + -- torch.Tensor(TN, ...) + -- } + -- --> + -- { torch.Tensor(T1+...+TN, ...), torch.LongTensor({T1,T2,...,TN) } + local first_input = input[1] + self._gradOutput = self._gradOutput or first_input.new() + self.cinput = self.cinput or first_input.new() + self.cinput = self.cinput:type(first_input:type()) + + if self.lastOnly then + -- Call the transposer with the "last" argument == 1 + self.cinput.THRNN.LSTM_bt_to_sorted_tb( + gradOutput:cdata(), + self.sizes:cdata(), + self._gradOutput:cdata(), + self.mapping:cdata(), + self.sorted_by_batch_sizes:cdata(), + self.sorted_by_time_sizes:cdata(), + self.sorted_by_time_indices:cdata(), + 1) + else + -- Concatenate the gradOutput, + -- and call the transposer + self.cinput:cat(gradOutput, 1) + self.cinput.THRNN.LSTM_bt_to_sorted_tb( + self.cinput:cdata(), + self.sizes:cdata(), + self._gradOutput:cdata(), + self.mapping:cdata(), + self.sorted_by_batch_sizes:cdata(), + self.sorted_by_time_sizes:cdata(), + self.sorted_by_time_indices:cdata(), + 0) + end + -- updateGradInput decorated module + self.context = self.context or {} + self.context.sizes = self.sorted_by_batch_sizes + self.modules[1]:setContext(self.context) + local gradInput = self.modules[1]:updateGradInput(self._input, self._gradOutput) + + -- Final call to the de-transposer before returning + self.gradInput = {} + self._gradInput = self._gradInput or first_input.new() + self._gradInput = self._gradInput:type(first_input:type()) + self.cinput.THRNN.LSTM_sorted_tb_to_bt( + gradInput:cdata(), + self.sizes:cdata(), + self.mapping:cdata(), + self._gradInput:cdata(), + 0) + + local runningIdx = 1 + for i=1,#input do + self.gradInput[i] = self._gradInput:narrow(1,runningIdx, self.sizes[i]) + runningIdx = runningIdx + self.sizes[i] + end else - assert(torch.type(gradOutput) == 'table') - assert(torch.isTensor(gradOutput[1])) - self.indexes, self.mappedLengths = self._gradOutput.nn.VariableLength_FromSamples(gradOutput, self._gradOutput, self._mask) - end + -- Path for "dense/masked" representations of arrays of tensors, + -- where an array of tensors is transformed into its + -- dense/masked equivalent representation, i.e.: + -- { + -- torch.Tensor(T1, ...), + -- torch.Tensor(T2, ...), + -- ..., + -- torch.Tensor(TN, ...) 
+ -- } + -- --> + -- { torch.Tensor(max_i{T}, N, ...), torch.ByteTensor(max_i{T}, N) } + -- where max_i{T} is the maximum sequence length + self._gradOutput = self._gradOutput or self._input.new() + if self.lastOnly then + assert(torch.isTensor(gradOutput)) + self._gradOutput.nn.VariableLength_FromFinal(self.indexes, self.mappedLengths, gradOutput, self._gradOutput) + else + assert(torch.type(gradOutput) == 'table') + assert(torch.isTensor(gradOutput[1])) + self.indexes, self.mappedLengths = self._gradOutput.nn.VariableLength_FromSamples(gradOutput, self._gradOutput, self._mask) + end - -- zero-mask the _gradOutput where mask is 1 - nn.utils.recursiveZeroMask(self._gradOutput, self._mask) + -- zero-mask the _gradOutput where mask is 1 + nn.utils.recursiveZeroMask(self._gradOutput, self._mask) - -- updateGradInput decorated module - local gradInput = self.modules[1]:updateGradInput(self._input, self._gradOutput) + -- updateGradInput decorated module + local gradInput = self.modules[1]:updateGradInput(self._input, self._gradOutput) - self.gradInput = self._input.nn.VariableLength_ToSamples(self.indexes, self.mappedLengths, gradInput) + self.gradInput = self._input.nn.VariableLength_ToSamples(self.indexes, self.mappedLengths, gradInput) + end return self.gradInput end @@ -83,4 +298,4 @@ end function VariableLength:setZeroMask() error"Not Supported" -end \ No newline at end of file +end diff --git a/init.lua b/init.lua index dd222d1..bb589ac 100644 --- a/init.lua +++ b/init.lua @@ -17,6 +17,7 @@ function nn.require(packagename) end end +require('rnn.THRNN') -- c lib: require "paths" @@ -93,6 +94,7 @@ require('rnn.RecurrentAttention') -- sequencer + recurrent modules require('rnn.SeqLSTM') +require('rnn.VSeqLSTM') require('rnn.SeqGRU') require('rnn.SeqBLSTM') require('rnn.SeqBGRU') @@ -113,4 +115,4 @@ require('rnn.BiSequencerLM') -- prevent likely name conflicts nn.rnn = rnn -return rnn \ No newline at end of file +return rnn diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt new file mode 100644 index 0000000..b1e5752 --- /dev/null +++ b/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) +CMAKE_POLICY(VERSION 2.6) +SET(THRNN_INSTALL_LIB_SUBDIR "${Torch_INSTALL_LUA_CPATH_SUBDIR}") +SET(THRNN_INSTALL_INCLUDE_SUBDIR "${Torch_INSTALL_INCLUDE_SUBDIR}") +ADD_SUBDIRECTORY(THRNN) diff --git a/lib/THRNN/CMakeLists.txt b/lib/THRNN/CMakeLists.txt new file mode 100644 index 0000000..04b1b71 --- /dev/null +++ b/lib/THRNN/CMakeLists.txt @@ -0,0 +1,83 @@ +CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) +CMAKE_POLICY(VERSION 2.6) + +IF(NOT Torch_FOUND) + FIND_PACKAGE(Torch REQUIRED) +ENDIF() + +IF(NOT TH_LIBRARIES) + SET(TH_LIBRARIES "TH") +ENDIF(NOT TH_LIBRARIES) +MESSAGE(STATUS "TH_LIBRARIES: ${TH_LIBRARIES}") + +IF(NOT THRNN_INSTALL_LIB_SUBDIR) + SET(THRNN_INSTALL_LIB_SUBDIR "lib" CACHE PATH "THRNN install library directory") +ENDIF() + +# Flags +# When using MSVC +IF(MSVC) + # we want to respect the standard, and we are bored of those **** . + ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1) + ADD_DEFINITIONS(-DTH_EXPORTS) +ENDIF(MSVC) + +IF (CMAKE_VERSION VERSION_LESS "3.1") + SET(CMAKE_C_FLAGS "-std=c99 ${CMAKE_C_FLAGS}") +ELSE () + SET(CMAKE_C_STANDARD 99) +ENDIF () + +# OpenMP support? 
+SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?") +IF (APPLE AND CMAKE_COMPILER_IS_GNUCC) + EXEC_PROGRAM (uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) + STRING (REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION}) + MESSAGE (STATUS "MAC OS Darwin Version: ${DARWIN_VERSION}") + IF (DARWIN_VERSION GREATER 9) + SET(APPLE_OPENMP_SUCKS 1) + ENDIF (DARWIN_VERSION GREATER 9) + EXECUTE_PROCESS (COMMAND ${CMAKE_C_COMPILER} -dumpversion + OUTPUT_VARIABLE GCC_VERSION) + IF (APPLE_OPENMP_SUCKS AND GCC_VERSION VERSION_LESS 4.6.2) + MESSAGE(STATUS "Warning: Disabling OpenMP (unstable with this version of GCC)") + MESSAGE(STATUS " Install GCC >= 4.6.2 or change your OS to enable OpenMP") + SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas") + SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?" FORCE) + ENDIF () +ENDIF () + +IF (WITH_OPENMP) + FIND_PACKAGE(OpenMP) + IF(OPENMP_FOUND) + MESSAGE(STATUS "Compiling with OpenMP support") + SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") + ENDIF(OPENMP_FOUND) +ENDIF (WITH_OPENMP) + +LINK_DIRECTORIES("${Torch_INSTALL_LIB}") + +SET(src init.c) +ADD_TORCH_LIBRARY(THRNN MODULE ${src}) +ADD_TORCH_LIBRARY(THRNN_static STATIC ${src}) +SET_TARGET_PROPERTIES(THRNN_static PROPERTIES COMPILE_FLAGS "-fPIC -DSTATIC_TH") + +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) +### Torch packages supposes libraries prefix is "lib" +SET_TARGET_PROPERTIES(THRNN PROPERTIES + PREFIX "lib" + IMPORT_PREFIX "lib") + +TARGET_LINK_LIBRARIES(THRNN ${TH_LIBRARIES}) + +# Luarocks bug pre-14.04 prevents us from setting it for Lua-Torch +IF(THRNN_SO_VERSION) + MESSAGE(STATUS "THRNN_SO_VERSION: ${THRNN_SO_VERSION}") + SET_TARGET_PROPERTIES(THRNN PROPERTIES + VERSION ${THRNN_SO_VERSION} + SOVERSION ${THRNN_SO_VERSION}) +ENDIF(THRNN_SO_VERSION) + +INSTALL(TARGETS THRNN LIBRARY DESTINATION ${THRNN_INSTALL_LIB_SUBDIR}) diff --git a/lib/THRNN/THRNN.h b/lib/THRNN/THRNN.h new file mode 100644 index 0000000..607d38a --- /dev/null +++ b/lib/THRNN/THRNN.h @@ -0,0 +1,27 @@ +#ifndef THRNN_H +#define THRNN_H + +#include +#include +#ifdef _OPENMP +#include +#endif + +#define THRNN_(NAME) TH_CONCAT_3(THRNN_, Real, NAME) + +typedef long THIndex_t; +typedef int THInteger_t; +typedef void THRNNState; + +#define THRNN_resizeAs_indices(I1, I2) \ + THLongStorage *size2 = THIndexTensor_(newSizeOf)(I2); \ + if (!THTensor_(isSize)(I1, size2)) \ + { \ + THTensor_(resize)(I1, size2, NULL); \ + } \ + THLongStorage_free(size2); + +#include "generic/THRNN.h" +#include + +#endif diff --git a/lib/THRNN/generic/LSTM.c b/lib/THRNN/generic/LSTM.c new file mode 100644 index 0000000..557bc4e --- /dev/null +++ b/lib/THRNN/generic/LSTM.c @@ -0,0 +1,944 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/LSTM.c" +#else + +#ifdef _OPENMP +#include +#endif + +/* Set tensor->size[0] MACRO */ +#ifndef THRNN_LSTM_SET_SIZE +#define THRNN_LSTM_SET_SIZE(t, dim, newSize) ( t->size[dim] = newSize ) +#endif + +/* Set tensor->size[0] MACRO */ +#ifndef THRNN_LSTM_SET_STRIDE +#define THRNN_LSTM_SET_STRIDE(t, dim, newStride) ( t->stride[dim] = newStride ) +#endif + +/* Increment storageOffset */ +#ifndef THRNN_LSTM_INCREMENT_SOFFSET +#define THRNN_LSTM_INCREMENT_SOFFSET(t, dim, incrementSize) ( t->storageOffset += incrementSize*t->stride[dim] ) +#endif + +/* Decrement storageOffset */ +#ifndef THRNN_LSTM_DECREMENT_SOFFSET +#define 
THRNN_LSTM_DECREMENT_SOFFSET(t, dim, decrementSize) ( t->storageOffset -= decrementSize*t->stride[dim] ) +#endif + +// structure to hold a bunch of tensors +// and delete them all together, +// increment their offset altogether, +// set their sizes/strides altogether, ... +struct THRNN_(buffer) { + THTensor** array; + THTensor* buf; + int* sizes; + int len; + int total_size; + int ndim; + THLongStorage* default_buffer_sizes; + /* declare as many members as desired, but the entire structure size must be known to the compiler. */ +}; + +// create the buffer from an exisiting tensor (possibly NULL), +// and a LongStorage of default sizes for new individual buffers +struct THRNN_(buffer)* THRNN_(create_buffer)(THTensor* buf, THLongStorage* default_buffer_sizes) +{ + THTensor** arr; + struct THRNN_(buffer)* x = malloc( sizeof( struct THRNN_(buffer) )); + x->array = NULL; + x->buf = buf? buf:THTensor_(new)(); + x->sizes = NULL; + x->len = 0; + x->total_size = 0; + x->ndim = default_buffer_sizes->size; + x->default_buffer_sizes = default_buffer_sizes; + return x; +}; + +// This will add a tensor to the buffer, increasing the total size by +// 'size' times the product of sizes of all the other dimensions +THTensor* THRNN_(add_tensor_to_buffer)(struct THRNN_(buffer)* buffer, int size) +{ + buffer->len += 1; + buffer->total_size += size; + buffer->sizes = (int*)realloc(buffer->sizes, buffer->len * sizeof(int)); + buffer->sizes[buffer->len-1] = size; + buffer->array = (THTensor**)realloc(buffer->array, buffer->len * sizeof(THTensor**)); + THTensor* new_guy = THTensor_(new)(); + buffer->array[buffer->len-1] = new_guy; + return new_guy; +}; + +// Once the user has added all the tensors he needs to the buffer, +// he can 'compile' it, i.e. allocate one giant buffer that will be sliced +void THRNN_(compile_buffer)(struct THRNN_(buffer)* buffer) +{ + THLongStorage* bufferSize = THLongStorage_newWithSize(buffer->ndim); + int lastDimSize = 0; + int i; + for (i = 0; i < buffer->len; i++) + { + lastDimSize += buffer->sizes[i]; + } + for (i = 0; i < buffer->ndim-1; i++) + { + bufferSize->data[i] = buffer->default_buffer_sizes->data[i]; + THArgCheck(bufferSize->data[i] > 0, 1, "The buffer sizes must be > 0."); + } + bufferSize->data[buffer->ndim-1] = lastDimSize; + + THTensor_(resize)(buffer->buf, bufferSize, NULL); + THLongStorage_free(bufferSize); + int runningIdx = 0; + for (i = 0; i < buffer->len; i++) + { + if (buffer->sizes[i]) + { + THTensor_(narrow)(buffer->array[i], buffer->buf, buffer->ndim-1, runningIdx, buffer->sizes[i]); + } + runningIdx += buffer->sizes[i]; + } + return; +}; + +// Delete the buffer and all its components +void THRNN_(delete_buffer)(struct THRNN_(buffer)* buffer, int delete_internal_buffer) +{ + int i; + for (i = 0; i < buffer->len; i++) + { + THTensor_(free)(buffer->array[i]); + } + free(buffer->sizes); + free(buffer->array); + if (delete_internal_buffer) + THTensor_(free)(buffer->buf); + THLongStorage_free(buffer->default_buffer_sizes); + free(buffer); + return; +}; + +// batch set size +void THRNN_(buffer_set_size)(struct THRNN_(buffer)* buffer, int dim, int newSize) +{ + int i; + for (i = 0; i < buffer->len; i++) + { + THRNN_LSTM_SET_SIZE(buffer->array[i], dim, newSize); + } +} + +// batch set stride +void THRNN_(buffer_set_stride)(struct THRNN_(buffer)* buffer, int dim, int newStride) +{ + int i; + for (i = 0; i < buffer->len; i++) + { + THRNN_LSTM_SET_STRIDE(buffer->array[i], dim, newStride); + } +} + +// batch increment offset +void THRNN_(buffer_increment_soffset)(struct 
THRNN_(buffer)* buffer, int dim, int offset) +{ + int i; + for (i = 0; i < buffer->len; i++) + { + THRNN_LSTM_INCREMENT_SOFFSET(buffer->array[i], dim, offset); + } +} + +// batch decrement offset +void THRNN_(buffer_decrement_soffset)(struct THRNN_(buffer)* buffer, int dim, int offset) +{ + int i; + for (i = 0; i < buffer->len; i++) + { + THRNN_LSTM_DECREMENT_SOFFSET(buffer->array[i], dim, offset); + } +} + +// Convenient function to print tensors +static void THRNN_(printTensor)(char* name, THTensor *input) +{ + printf("Tensor %s\n", name); + + long ndim = THTensor_(nDimension)(input); + int i; + printf("- ndim %lu\n", ndim); + for (i = 0; i < ndim; i++) + { + printf("\tdimension #%i:\n", i); + int sz = THTensor_(size)(input, i); + int st = THTensor_(stride)(input, i); + printf("\t- size: %i\n", sz); + printf("\t- stride: %i\n", st); + } + printf("\n"); +} + +// Convenient function to print the mean of a tensor +static void THRNN_(printMean)(char* name, THTensor *input) +{ + printf("Tensor %s\n", name); + long nelements = THTensor_(nElement)(input); + long ndim = THTensor_(nDimension)(input); + accreal mean = THTensor_(meanall)(input);; + int i; + printf("\t- nelements %lu\n", nelements); + printf("\t- mean: %g\n", mean); +} + +#ifndef THRNN_LSTM_BASIC_LSTM_CELL +#define THRNN_LSTM_BASIC_LSTM_CELL 1 +#endif + +// Take a array of tensors, each tensor +// representing the inputs over time, or +// the last input in time if 'last' is == 1 +// Here: +// input --> concatenation of these tensors over the first (batch) dimension +// sizes --> sizes of tensors +// This function computes: +// output --> transposed time-first of input, +// i.e. for each time step, the elements are sequentially +// aligned if they are from the first tensor, the second, ... +// where first, second, ... are the sorted tensors by decreasing size +// sorted_by_batch_sizes --> the sizes sorted by batch size in decreasing order +// mapping --> the mapping from the output indices to the input indices. 
Useful for the inverse operation +void THRNN_(LSTM_bt_to_sorted_tb)( + THRNNState *state, + THTensor *input, + THLongTensor *sizes, + THTensor *output, + THLongTensor *mapping, + THLongTensor *sorted_by_batch_sizes, + THLongTensor *sorted_by_time_sizes, + THLongTensor *sorted_by_time_indices, + int last) +{ + long sizes_dim = THLongTensor_nDimension(sizes); + long input_dim = THTensor_(nDimension)(input); + THArgCheck(sizes_dim == 1, 2, "sizes must have 1 dimension"); + + long input_size_0 = THTensor_(size)(input, 0); + long nelements = THTensor_(nElement)(input); + long nelements_left = nelements / input_size_0; + long bsize = THLongTensor_size(sizes, 0); + + THLongTensor* sizes_copy = THLongTensor_new(); + THLongTensor* cumsum_sizes = THLongTensor_new(); + THLongTensor_sort(sorted_by_time_sizes, sorted_by_time_indices, sizes, 0, 1); + THLongTensor_resize1d(sizes_copy, bsize); + THLongTensor_cumsum(cumsum_sizes, sizes, 0); + THLongTensor_copy(sizes_copy, sizes); + + + long* sizes_data = THLongTensor_data(sizes); + long* sizes_copy_data = THLongTensor_data(sizes_copy); + long* cumsum_sizes_data = THLongTensor_data(cumsum_sizes); + long* sorted_by_time_sizes_data = THLongTensor_data(sorted_by_time_sizes); + long* sorted_by_time_indices_data = THLongTensor_data(sorted_by_time_indices); + real* input_data = THTensor_(data)(input); + + THArgCheck(THLongTensor_isContiguous(sizes), 2, "sizes vector must be contiguous"); + + long total_size = last?0:input_size_0; + long i,b; + if (last) + { + for (b = 0; b < bsize; b++) + { + total_size += sizes_data[b]; + } + } + + // Resize the output + long* new_size = malloc(input_dim*sizeof(long)); + long lnelements_left = nelements; + + for (i = 0;i < input_dim; i++) + { + new_size[i] = i?THTensor_(size)(input, i):total_size; + } + + THTensor_(resizeNd)(output, input_dim, new_size, NULL); + real* output_data = THTensor_(data)(output); + free(new_size); + + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THTensor_(isContiguous)(input), 1, "input vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(output), 3, "output vector must be contiguous"); + THArgCheck(THLongTensor_isContiguous(mapping), 4, "mapping vector must be contiguous"); + THArgCheck(THLongTensor_isContiguous(sorted_by_batch_sizes), 5, "sorted_by_batch_sizes must be contiguous."); + THArgCheck(THLongTensor_isContiguous(sorted_by_time_sizes), 6, "sorted_by_time_sizes must be contiguous."); + THArgCheck(THLongTensor_isContiguous(sorted_by_time_indices), 7, "sorted_by_time_indices must be contiguous."); + + long idx = 0; + THLongTensor_resize1d(mapping, total_size); + THLongTensor_resize1d(sorted_by_batch_sizes, sorted_by_time_sizes_data[0]); + THLongTensor_zero(sorted_by_batch_sizes); + long* mapping_data = THLongTensor_data(mapping); + long* sorted_by_batch_sizes_data = THLongTensor_data(sorted_by_batch_sizes); + + // if last, then the input only has end-of-sequence values defined + // for each element of the batch + if (last) + { + THTensor_(zero)(output); +// THRNN_(printTensor)("output", output); +// THRNN_(printTensor)("input", input); + for (i = 0; i < sorted_by_time_sizes_data[0]; i++) + { + for (b = 0; b < bsize; b++) + { + long sidx = sorted_by_time_indices_data[b]; + long timesteps = sizes_copy_data[sidx]; + if (timesteps > 0) + { + long lidx = ((sidx>0?cumsum_sizes_data[sidx-1]:0) + sizes_data[sidx] - sizes_copy_data[sidx]); + if (timesteps == 1) + { + memcpy(output_data + idx*nelements_left, input_data + sidx*nelements_left, 
nelements_left*sizeof(real)); + } + mapping_data[lidx] = idx; + sizes_copy_data[sidx]--; + sorted_by_batch_sizes_data[i]++; + idx++; + } + } + } + } + else + { + for (i = 0; i < sorted_by_time_sizes_data[0]; i++) + { + for (b = 0; b < bsize; b++) + { + long sidx = sorted_by_time_indices_data[b]; + long timesteps = sizes_copy_data[sidx]; + if (timesteps > 0) + { + long lidx = ((sidx>0?cumsum_sizes_data[sidx-1]:0) + sizes_data[sidx] - sizes_copy_data[sidx]); + memcpy(output_data + idx*nelements_left, input_data + lidx*nelements_left, nelements_left*sizeof(real)); + mapping_data[lidx] = idx; + sizes_copy_data[sidx]--; + sorted_by_batch_sizes_data[i]++; + idx++; + } + } + } + } + + THLongTensor_free(sizes_copy); + THLongTensor_free(cumsum_sizes); +} + + +// Reverse operation from LSTM_bt_to_sorted_tb +void THRNN_(LSTM_sorted_tb_to_bt)( + THRNNState *state, + THTensor *input, + THLongTensor *sizes, + THLongTensor *mapping, + THTensor *output, + int last) +{ + long input_dim = THTensor_(nDimension)(input); + + long input_size_0 = THTensor_(size)(input, 0); + long nelements = THTensor_(nElement)(input); + long nelements_left = nelements / input_size_0; + long batch_size = THLongTensor_size(sizes, 0); + + // Resize the output + long* new_size = malloc(input_dim*sizeof(long)); + long lnelements_left = nelements; + long total_size = last?batch_size:input_size_0; + long i,b; + + for (i = 0;i < input_dim; i++) + { + new_size[i] = i?THTensor_(size)(input, i):total_size; + } + + THTensor_(resizeNd)(output, input_dim, new_size, NULL); + real* output_data = THTensor_(data)(output); + free(new_size); + + long* mapping_data = THLongTensor_data(mapping); + real* input_data = THTensor_(data)(input); + long* sizes_data = THLongTensor_data(sizes); + + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THTensor_(isContiguous)(input), 1, "input vector must be contiguous"); + THArgCheck(THLongTensor_isContiguous(mapping), 3, "sizes vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(output), 4, "output vector must be contiguous"); + THArgCheck(THLongTensor_isContiguous(sizes), 2, "sizes vector must be contiguous"); + + + if (last) + { + THLongTensor* cumsum_sizes = THLongTensor_new(); + THLongTensor_cumsum(cumsum_sizes, sizes, 0); + long* cumsum_sizes_data = THLongTensor_data(cumsum_sizes); + for (i = 0; i < batch_size; i++) + { + long idx = cumsum_sizes_data[i]-1; + memcpy(output_data + i*nelements_left, + input_data + mapping_data[idx]*nelements_left, + nelements_left*sizeof(real)); + } + THLongTensor_free(cumsum_sizes); + } + else + { + for (i = 0; i < input_size_0; i++) + { + memcpy(output_data + i*nelements_left, + input_data + mapping_data[i]*nelements_left, + nelements_left*sizeof(real)); + } + } +} + +/* + * That macro is used in LSTM_updateOutput + * and LSTM_backward, hence the generalization to a macro. + * create buffers + slices as needed + views as needed for the buffer of the forward pass. 
+ */ +#ifndef THRNN_LSTM_FORWARD_BUFFERS +#define THRNN_LSTM_FORWARD_BUFFERS() \ + THLongStorage* default_buffer_sizes = THLongStorage_newWithSize(1); \ + default_buffer_sizes->data[0] = nEmbeddings; \ + struct THRNN_(buffer)* rnn_buffer = THRNN_(create_buffer)(buffer, default_buffer_sizes); \ +\ + THTensor* ifog_hat_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, 4 * nEmbeddings * outputFeatures); \ + THTensor* ifo_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, 3 * nEmbeddings * outputFeatures); \ + THTensor* g_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, nEmbeddings * outputFeatures); \ + THTensor* c_buffer_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, nEmbeddings * outputFeatures); \ + THTensor* c_buffer2_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, nEmbeddings * outputFeatures); \ + THTensor* c_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, nEmbeddings * outputFeatures); \ + THTensor* tanh_c_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, nEmbeddings * outputFeatures); \ +\ + THTensor* ifo_hat_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, 0); \ + THTensor* g_hat_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, 0); \ + THTensor* i_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, 0); \ + THTensor* f_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, 0); \ + THTensor* o_t = THRNN_(add_tensor_to_buffer)(rnn_buffer, 0); \ +\ + THRNN_(compile_buffer)(rnn_buffer); \ +\ + THTensor_(resize2d)(ifog_hat_t, nEmbeddings, 4*outputFeatures); \ + THTensor_(resize2d)(ifo_t, nEmbeddings, 3*outputFeatures); \ + THTensor_(resize2d)(g_t, nEmbeddings, outputFeatures); \ + THTensor_(resize2d)(c_buffer_t, nEmbeddings, outputFeatures); \ + THTensor_(resize2d)(c_buffer2_t, nEmbeddings, outputFeatures); \ + THTensor_(resize2d)(c_t, nEmbeddings, outputFeatures); \ + THTensor_(resize2d)(tanh_c_t, nEmbeddings, outputFeatures); \ + THTensor_(narrow)(ifo_hat_t, ifog_hat_t, 1, 0, 3*outputFeatures); \ + THTensor_(narrow)(g_hat_t, ifog_hat_t, 1, 3*outputFeatures, outputFeatures); \ + THTensor_(narrow)(i_t, ifo_t, 1, 0, outputFeatures); \ + THTensor_(narrow)(f_t, ifo_t, 1, outputFeatures, outputFeatures); \ + THTensor_(narrow)(o_t, ifo_t, 1, outputFeatures*2, outputFeatures); +#endif + +#ifndef THRNN_LSTM_GRU_CELL +#define THRNN_LSTM_GRU_CELL 2 +#endif + +// input is organized time-first, concatenated, +// in decreasing order of batch size per time step. 
+// for example, if the batch of sequences is: +// { Tensor(2,100), Tensor(5,100), Tensor(4,100) }, +// then the data has to be transposed and sorted and will look like: +// input: torch.Tensor(5+4+2) +// sizes: torch.LongTensor({5,4,2}) +void THRNN_(LSTM_updateOutput)( + THRNNState *state, + THTensor *input, + THTensor *inputC, + THTensor *inputH, + THLongTensor *sizes, + THTensor *output, + THTensor *weight, + THTensor *bias, + THTensor *buffer, + int inputFeatures, + int outputFeatures, + int cellType, + int train, + int delete_buffer) +{ + long sizesDim = THLongTensor_nDimension(sizes); + long inputDim = THTensor_(nDimension)(input); + THArgCheck(sizesDim == 1, 2, "sizes must have 1 dimension"); + THArgCheck(inputDim == 2, 1, "input must have 2 dimensions"); + + // Retrieve all the dimensions of the problem + long totalTimeSteps = THLongTensor_size(sizes, 0); + long nEmbeddings = THTensor_(size)(input, 0); + long inputSize1 = THTensor_(size)(input, 1); + long weightSize0 = THTensor_(size)(weight, 0); + long weightSize1 = THTensor_(size)(weight, 1); + long* sizesData = THLongTensor_data(sizes); + + // Resize the output + THTensor_(resize2d)(output, nEmbeddings, outputFeatures); + + + // Compute the sum of sizes to + // further check they're equal to the number of embeddings + THLongTensor* sizesSum = THLongTensor_new(); + THLongTensor_sum(sizesSum, sizes, 0, 1); + long ssum = sizesSum->storage->data[0]; + THLongTensor_free(sizesSum); + + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THTensor_(isContiguous)(input), 1, "input vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(inputH), 6, "inputH vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(inputC), 5, "inputC vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(output), 3, "output vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(weight), 4, "weight matrix must be contiguous"); + THArgCheck(THTensor_(isContiguous)(bias), 5, "bias vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(buffer), 8, "buffer tensor must be contiguous"); + THArgCheck(ssum == nEmbeddings, 9, "the sum of sizes must be equal to the number of input embeddings"); + long t; + + if (cellType == THRNN_LSTM_BASIC_LSTM_CELL) + { + + /* + * Definitions: + * w == weights for the fully-connected layer. + * Weights are concatenated for [input_gate,forget_gate,output_gate,input_tansformation], in this exact order. 
+ * b == expanded bias + * i_t == input gate at timestep t + * f_t == forget gate at timestep t + * o_t == output gate at timestep t + * g_t == input transformation at timestep t + * ifog_hat_t == concatenation of i_t,f_t, o_t and g_t before applying the non-linearities + * ifo_t == concatenation of i_t,f_t and o_t + * input_t == input at timestep t + * w_x == weights for the input part + * w_h == weights for the hidden part + * + * Sequence of computations: + * 1/ matrix-multiply input for all timesteps by weights for input + * ifog_hat_{all} = w_x * input_{all} + b + * + * 2/ for each timestep t do: + * ifog_hat_t += w_h * h_t + b + * ifo_t = sigmoid(ifo_hat_t) + * g_t = tanh(g_hat_t) + * c_buffer_t = f_t o c_{t-1} + * c_buffer2_t = i_t o g_t + * c_t = c_buffer_t + c_buffer2_t + * tanh_c_t = tanh(c_t) + * output_t = o_t o tanh_c_t + */ + + long totalWeightColumns = outputFeatures*4; + THAssert(totalWeightColumns == weightSize1); + + // Alright, this will create ALL the buffers we need for the forward pass + // Allocating a big contiguous dude will help us slice intelligently + // when performing our dear computations. + // This will also keep track of stuff we need to store for the backward pass. + THRNN_LSTM_FORWARD_BUFFERS() + + // Initialize a few extra "non-temporal" buffers + THLongStorage* weight_buffer_sizes = THLongStorage_newWithSize(1); + weight_buffer_sizes->data[0] = 1; + struct THRNN_(buffer)* weight_buffer = THRNN_(create_buffer)(NULL, weight_buffer_sizes); + THTensor* w_h = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* w_x = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* c_tless = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* h_tless = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + + THRNN_(compile_buffer)(weight_buffer); + + // Slice weights for hidden nodes and inputs + THTensor_(narrow)(w_h, weight, 0, inputFeatures, outputFeatures); + THTensor_(narrow)(w_x, weight, 0, 0, inputFeatures); + + // Resize/expand the bias + THTensor_(resize2d)(bias, 1, THTensor_(size)(bias, 0)); + THRNN_LSTM_SET_STRIDE(bias, 0, 0); + THRNN_LSTM_SET_SIZE(bias, 0, nEmbeddings); + + // Ok now start computations + // First, pre-multiply the input with the weights + // Pre-multiply the input by the weight + THTensor_(addmm)(ifog_hat_t, 1, bias, 1, input, w_x); + + for (t = 0; t < totalTimeSteps; t++) + { + long bsize = sizesData[t]; + + // Narrow the buffers along the first dimension + // to be equal to the batch size for this particular timestep + // We're not using THTensor_(narrow) for efficiency. 
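+      // (the SET_SIZE/INCREMENT_SOFFSET macros only patch size[0] and storageOffset
+      // of the views created by compile_buffer, so the per-timestep slices are
+      // obtained without rebuilding narrowed views at every step)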
+ THRNN_(buffer_set_size)(rnn_buffer, 0, bsize); + THRNN_LSTM_SET_SIZE(output, 0, bsize); + THRNN_LSTM_SET_SIZE(bias, 0, bsize); + + if (!t) + { + THTensor_(narrow)(c_tless, inputC, 0, 0, sizesData[t]); + THTensor_(narrow)(h_tless, inputH, 0, 0, sizesData[t]); + } + // Add hidden values after mat-mul + // to the existing input values after matmul (obtained from the pre-multiplication) + THTensor_(addmm)(ifog_hat_t, 1, ifog_hat_t, 1, h_tless, w_h); + + // Sigmoidize the first 3 slices + // Non-contiguous operation but heh + THTensor_(sigmoid)(ifo_t, ifo_hat_t); + + // Tanhize the last slice + // Non-contiguous operation but heh + THTensor_(tanh)(g_t, g_hat_t); + + // apply the forget gate + // Contiguous operation + THTensor_(cmul)(c_buffer_t, f_t, c_tless); + + // apply the input gate + // Contiguous operation + THTensor_(cmul)(c_buffer2_t, i_t, g_t); + + // add the residual cell value to the previous cell + // and update the cell state + // Contiguous operation + THTensor_(cadd)(c_t, c_buffer_t, 1, c_buffer2_t); + + // apply the tanh to the output cell + // Contiguous operations + THTensor_(tanh)(tanh_c_t, c_t); + + // apply the output gate + // and update the hidden state + // Contiguous operations + THTensor_(cmul)(output, tanh_c_t, o_t); + + if (t < totalTimeSteps-1) + { + THTensor_(narrow)(c_tless, c_t, 0, 0, sizesData[t+1]); + THTensor_(narrow)(h_tless, output, 0, 0, sizesData[t+1]); + } + + // After the computations for that timestep are done, + // increment the offset to switch to the next timestep + THRNN_(buffer_increment_soffset)(rnn_buffer, 0, bsize); + THRNN_LSTM_INCREMENT_SOFFSET(output, 0, bsize); + } + // Get all the buffers back to their original states + THRNN_(buffer_decrement_soffset)(rnn_buffer, 0, nEmbeddings); + THRNN_(buffer_set_size)(rnn_buffer, 0, nEmbeddings); + THRNN_LSTM_DECREMENT_SOFFSET(output, 0, nEmbeddings); + + // Reshape the bias to its original shape + THRNN_LSTM_SET_SIZE(bias, 0, 1); + THTensor_(resize1d)(bias, THTensor_(size)(bias, 1)); + THRNN_LSTM_SET_STRIDE(bias, 0, 1); + + // Reshape the hidden and cell values to the max batch size + // resize the output back to its total size + THRNN_LSTM_SET_SIZE(output, 0, nEmbeddings); + + THRNN_(delete_buffer)(rnn_buffer, train?0:delete_buffer); + THRNN_(delete_buffer)(weight_buffer, 0); + + } + return; +} + + + + +void THRNN_(LSTM_backward)( + THRNNState *state, + THTensor *input, + THTensor *inputC, + THTensor *inputH, + THLongTensor *sizes, + THTensor *gradOutput, + THTensor *gradInput, + THTensor *gradInputH, + THTensor *gradInputC, + THTensor *weight, + THTensor *bias, + THTensor *buffer, + THTensor *weightBuffer, + THTensor *gradInputBuffer, + THTensor *gradWeight, + THTensor *gradBias, + THTensor *output, + real scale, + int last, + int inputFeatures, + int outputFeatures, + int cellType, + int delete_buffer) +{ + long sizesDim = THLongTensor_nDimension(sizes); + long inputDim = THTensor_(nDimension)(input); + THArgCheck(sizesDim == 1, 2, "sizes must have 1 dimension"); + THArgCheck(inputDim == 2, 1, "input must have 2 dimensions"); + // Retrieve all the dimensions of the problem + long totalTimeSteps = THLongTensor_size(sizes, 0); + long nEmbeddings = THTensor_(size)(input, 0); + long inputSize1 = THTensor_(size)(input, 1); + long weightSize0 = THTensor_(size)(weight, 0); + long weightSize1 = THTensor_(size)(weight, 1); + long* sizesData = THLongTensor_data(sizes); + + // Resize the output + THTensor_(resize2d)(gradInput, nEmbeddings, inputSize1); + THTensor_(resize2d)(gradWeight, weightSize0, 
weightSize1); + THTensor_(resize1d)(gradBias, 4*outputFeatures); + + + THLongTensor* sizesSum = THLongTensor_new(); + THLongTensor_sum(sizesSum, sizes, 0, 1); + long ssum = sizesSum->storage->data[0]; + THLongTensor_free(sizesSum); + + // Make sure these inputs are contiguous to accelerate computations + THArgCheck(THTensor_(isContiguous)(input), 1, "input vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(inputH), 6, "inputH vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(inputC), 5, "inputC vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(output), 3, "output vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(weight), 4, "weight matrix must be contiguous"); + THArgCheck(THTensor_(isContiguous)(bias), 5, "bias vector must be contiguous"); + THArgCheck(THTensor_(isContiguous)(buffer), 8, "buffer tensor must be contiguous"); + THArgCheck(ssum == nEmbeddings, 9, "the sum of sizes must be equal to the number of input embeddings"); + THArgCheck(THTensor_(nDimension)(gradOutput) == THTensor_(nDimension)(output), 3, "output and gradOutput do not have the same # of dimensions"); + THArgCheck(THTensor_(size)(gradOutput, 1) == THTensor_(size)(output, 1), 3, "when the output type is 'last', gradOutput and output must have the same size for the 2nd dimension"); + + long t; + + if (cellType == THRNN_LSTM_BASIC_LSTM_CELL) + { + // Resize the buffer + long totalWeightColumns = outputFeatures*4; + THAssert(totalWeightColumns == weightSize1); + + // Alright, this buffer is going to contain ALL the buffers we need. + // This will also keep track of stuff we need to store for the backward pass. + THRNN_LSTM_FORWARD_BUFFERS() + + THLongStorage* gi_default_buffer_sizes = THLongStorage_newWithSize(1); + gi_default_buffer_sizes->data[0] = nEmbeddings; + struct THRNN_(buffer)* gi_rnn_buffer = THRNN_(create_buffer)(gradInputBuffer, gi_default_buffer_sizes); + + THTensor* difog_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, 4 * nEmbeddings * outputFeatures); + THTensor* dc_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, nEmbeddings * outputFeatures); + THTensor* dtanh_c_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, nEmbeddings * outputFeatures); + THTensor* ones = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, nEmbeddings * outputFeatures); + THTensor* x_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, nEmbeddings * inputFeatures); + THTensor* dh_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, nEmbeddings * outputFeatures); + + THTensor* difo_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, 0); + THTensor* di_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, 0); + THTensor* df_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, 0); + THTensor* do_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, 0); + THTensor* dg_t = THRNN_(add_tensor_to_buffer)(gi_rnn_buffer, 0); + + THRNN_(compile_buffer)(gi_rnn_buffer); + + THTensor_(resize2d)(difog_t, nEmbeddings, 4*outputFeatures); + THTensor_(resize2d)(dc_t, nEmbeddings, outputFeatures); + THTensor_(resize2d)(dtanh_c_t, nEmbeddings, outputFeatures); + THTensor_(resize2d)(ones, nEmbeddings, outputFeatures); + THTensor_(resize2d)(dh_t, nEmbeddings, outputFeatures); + + THTensor_(narrow)(difo_t, difog_t, 1, 0, 3*outputFeatures); + THTensor_(narrow)(di_t, difog_t, 1, 0, outputFeatures); + THTensor_(narrow)(df_t, difog_t, 1, outputFeatures, outputFeatures); + THTensor_(narrow)(do_t, difog_t, 1, 2*outputFeatures, outputFeatures); + THTensor_(narrow)(dg_t, difog_t, 1, 3*outputFeatures, outputFeatures); + THTensor_(fill)(ones, 
1); + + + THLongStorage* weight_buffer_sizes = THLongStorage_newWithSize(1); + weight_buffer_sizes->data[0] = 1; + struct THRNN_(buffer)* weight_buffer = THRNN_(create_buffer)(weightBuffer, weight_buffer_sizes); + + THTensor* w_h = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* w_x = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* w_hT = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* w_xT = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* grad_weight_h = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* grad_weight_x = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* h_tlessT = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* x_tT = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* c_tless = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* h_tless = THRNN_(add_tensor_to_buffer)(weight_buffer, 0); + THTensor* grad_bias_buffer = THRNN_(add_tensor_to_buffer)(weight_buffer, 4*outputFeatures); + THTensor* grad_input_c = THRNN_(add_tensor_to_buffer)(weight_buffer, sizesData[0] * outputFeatures); + THTensor* grad_input_h = THRNN_(add_tensor_to_buffer)(weight_buffer, sizesData[0] * outputFeatures); + + THRNN_(compile_buffer)(weight_buffer); + + THTensor_(narrow)(w_h, weight, 0, inputFeatures, outputFeatures); + THTensor_(narrow)(w_x, weight, 0, 0, inputFeatures); + THTensor_(narrow)(grad_weight_h, gradWeight, 0, inputFeatures, outputFeatures); + THTensor_(narrow)(grad_weight_x, gradWeight, 0, 0, inputFeatures); + THTensor_(transpose)(w_hT, w_h, 0, 1); + THTensor_(transpose)(w_xT, w_x, 0, 1); + THTensor_(resize2d)(grad_input_c, sizesData[0], outputFeatures); + THTensor_(resize2d)(grad_input_h, sizesData[0], outputFeatures); + THTensor_(resize2d)(grad_bias_buffer, 1, 4*outputFeatures); + + // Increment all the buffer up to the last timestep + THRNN_(buffer_increment_soffset)(rnn_buffer, 0, nEmbeddings); + THRNN_(buffer_increment_soffset)(gi_rnn_buffer, 0, nEmbeddings); + THRNN_LSTM_INCREMENT_SOFFSET(output, 0, nEmbeddings); + + // If the output type is not 'last', + // then it means that gradOutput has the same size as output. + // Otherwise it must be of size batchSize x outputSize + if (!last) + { + THArgCheck(THTensor_(size)(gradOutput, 0) == THTensor_(size)(output, 0), 3, "when the output type is not 'last', gradOutput and output must have the same size for the first dimension"); + THRNN_LSTM_INCREMENT_SOFFSET(gradOutput, 0, nEmbeddings); + } + else + { + THArgCheck(THTensor_(size)(gradOutput, 0) == sizesData[0], 3, "when the output type is 'last', the size of gradOutput for the first dimension must be equal to the batch size"); + } + + THTensor_(copy)(grad_input_c, gradInputC); + THTensor_(copy)(grad_input_h, gradInputH); + + for (t = totalTimeSteps-1; t >= 0; t--) + { + long bsize = sizesData[t]; + + THRNN_(buffer_decrement_soffset)(rnn_buffer, 0, bsize); + THRNN_(buffer_decrement_soffset)(gi_rnn_buffer, 0, bsize); + THRNN_LSTM_DECREMENT_SOFFSET(output, 0, bsize); + + // Narrow the buffers along the first dimension + // We're not using THTensor_(narrow) for efficiency. 
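+      // Instead, the THRNN_(buffer_set_size) calls and the
+      // THRNN_LSTM_SET_SIZE / *_SOFFSET macros below shrink each tensor in
+      // place, by resetting its size along dimension 0 and shifting its
+      // storage offset, so no new tensor views need to be allocated at
+      // every timestep of the backward loop.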
+ THRNN_(buffer_set_size)(rnn_buffer, 0, bsize); + THRNN_(buffer_set_size)(gi_rnn_buffer, 0, bsize); + THRNN_LSTM_SET_SIZE(grad_input_c, 0, bsize); + THRNN_LSTM_SET_SIZE(grad_input_h, 0, bsize); + THRNN_LSTM_SET_SIZE(output, 0, bsize); + + if (!t) + { + THTensor_(narrow)(c_tless, inputC, 0, 0, sizesData[t]); + THTensor_(narrow)(h_tless, inputH, 0, 0, sizesData[t]); + } + else + { + THRNN_LSTM_DECREMENT_SOFFSET(output, 0, sizesData[t-1]); + THRNN_LSTM_DECREMENT_SOFFSET(c_t, 0, sizesData[t-1]); + THTensor_(narrow)(h_tless, output, 0, 0, sizesData[t]); + THTensor_(narrow)(c_tless, c_t, 0, 0, sizesData[t]); + THRNN_LSTM_INCREMENT_SOFFSET(output, 0, sizesData[t-1]); + THRNN_LSTM_INCREMENT_SOFFSET(c_t, 0, sizesData[t-1]); + } + // If the output type is not 'last', + // We need to decrement the offset for gradOutput as well, + // and set the size + if (!last) + { + THRNN_LSTM_DECREMENT_SOFFSET(gradOutput, 0, bsize); + THRNN_LSTM_SET_SIZE(gradOutput, 0, bsize); + } + // If we are at the last timestep, copy gradInputH and add gradOutput directly to dh_t, + // otherwise add it if the output type is not last. + // For all timesteps != last_timestep, + // accumulate the gradients from the next timestep + THTensor_(copy)(dh_t, grad_input_h); + if (t == totalTimeSteps-1 || !last) + { + THTensor_(cadd)(dh_t, dh_t, 1, gradOutput); + } + // Compute do_t = dh_t o tanh(c_t) + THTensor_(cmul)(do_t, dh_t, tanh_c_t); + + // Compute dc_t += (1-tanh^2(c_t)) o o_t o dh_t + THTensor_(cmul)(dtanh_c_t, tanh_c_t, tanh_c_t); + THTensor_(cadd)(dc_t, ones, -1, dtanh_c_t); + THTensor_(cmul)(dc_t, dc_t, dh_t); + THTensor_(cmul)(dc_t, dc_t, o_t); + THTensor_(cadd)(grad_input_c, grad_input_c, 1, dc_t); + + // Now compute di_t = dc_t o g_t + THTensor_(cmul)(di_t, grad_input_c, g_t); + + // Now compute df_t = dc_t o c(t-1) + THTensor_(cmul)(df_t, grad_input_c, c_tless); + + // Now compute dg_t = dc_t o i_t + THTensor_(cmul)(dg_t, grad_input_c, i_t); + THTensor_(cmul)(grad_input_c, grad_input_c, f_t); + + + // Compute di_t = di_t o i_t o (1-i_t) + // Compute df_t = df_t o f_t o (1-f_t) + // Compute do_t = df_t o f_t o (1-o_t) + // Compute dg_t = dg_t o (1-tanh^2(g_t)) + THTensor_(cadd)(dc_t, ones, -1, i_t); + THTensor_(cmul)(di_t, di_t, dc_t); + THTensor_(cmul)(di_t, di_t, i_t); + + THTensor_(cadd)(dc_t, ones, -1, f_t); + THTensor_(cmul)(df_t, df_t, dc_t); + THTensor_(cmul)(df_t, df_t, f_t); + + THTensor_(cadd)(dc_t, ones, -1, o_t); + THTensor_(cmul)(do_t, do_t, dc_t); + THTensor_(cmul)(do_t, do_t, o_t); + + THTensor_(cmul)(dc_t, g_t, g_t); + THTensor_(cadd)(dtanh_c_t, ones, -1, dc_t); + THTensor_(cmul)(dg_t, dg_t, dtanh_c_t); + + + THTensor_(transpose)(h_tlessT, h_tless, 0, 1); + + // Accumulate the gradient wrt to the bias + THTensor_(sum)(grad_bias_buffer, difog_t, 0, 1); + THTensor_(cadd)(gradBias, gradBias, scale, grad_bias_buffer); + + // Accumulate the gradient wrt to the weight, for hidden nodes only + THTensor_(addmm)(grad_weight_h, 1, grad_weight_h, scale, h_tlessT, difog_t); + + // Compute the gradient wrt the input + THTensor_(zero)(grad_input_h); + THTensor_(addmm)(grad_input_h, 0, grad_input_h, 1, difog_t, w_hT); + } + + THRNN_(buffer_set_size)(rnn_buffer, 0, nEmbeddings); + THRNN_(buffer_set_size)(gi_rnn_buffer, 0, nEmbeddings); + + THTensor_(transpose)(x_tT, input, 0, 1); + THTensor_(addmm)(grad_weight_x, 1, grad_weight_x, scale, x_tT, difog_t); + THTensor_(zero)(gradInput); + THTensor_(addmm)(gradInput, 0, gradInput, 1, difog_t, w_xT); + + THRNN_LSTM_SET_SIZE(grad_input_h, 0, sizesData[0]); + 
THRNN_LSTM_SET_SIZE(grad_input_c, 0, sizesData[0]); + THRNN_LSTM_SET_SIZE(output, 0, nEmbeddings); + THRNN_LSTM_SET_SIZE(c_tless, 0, sizesData[0]); + THRNN_LSTM_SET_SIZE(h_tless, 0, sizesData[0]); + if (!last) + { + THRNN_LSTM_SET_SIZE(gradOutput, 0, nEmbeddings); + } + THRNN_(delete_buffer)(rnn_buffer, delete_buffer); + THRNN_(delete_buffer)(gi_rnn_buffer, delete_buffer); + THRNN_(delete_buffer)(weight_buffer, delete_buffer); + } + return; +} +#endif diff --git a/lib/THRNN/generic/THRNN.h b/lib/THRNN/generic/THRNN.h new file mode 100644 index 0000000..11ac59b --- /dev/null +++ b/lib/THRNN/generic/THRNN.h @@ -0,0 +1,64 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THRNN.h" +#else + +TH_API void THRNN_(LSTM_bt_to_sorted_tb)( + THRNNState *state, + THTensor *input, + THLongTensor *sizes, + THTensor *output, + THLongTensor *mapping, + THLongTensor *sorted_by_batch_sizes, + THLongTensor *sorted_by_time_sizes, + THLongTensor *sorted_by_time_indices, + int last); + +TH_API void THRNN_(LSTM_sorted_tb_to_bt)( + THRNNState *state, + THTensor *input, + THLongTensor *sizes, + THLongTensor *mapping, + THTensor *output, + int last); + +TH_API void THRNN_(LSTM_updateOutput)( + THRNNState *state, + THTensor *input, + THTensor *inputC, + THTensor *inputH, + THLongTensor *sizes, + THTensor *output, + THTensor *weight, + THTensor *bias, + THTensor *buffer, + int inputFeatures, + int outputFeatures, + int cellType, + int train, + int delete_buffer); + +TH_API void THRNN_(LSTM_backward)( + THRNNState *state, + THTensor *input, + THTensor *inputC, + THTensor *inputH, + THLongTensor *sizes, + THTensor *gradOutput, + THTensor *gradInput, + THTensor *gradInputH, + THTensor *gradInputC, + THTensor *weight, + THTensor *bias, + THTensor *buffer, + THTensor *weightBuffer, + THTensor *gradInputBuffer, + THTensor *gradWeight, + THTensor *gradBias, + THTensor *output, + real scale, + int last, + int inputFeatures, + int outputFeatures, + int cellType, + int delete_buffer); +#endif diff --git a/lib/THRNN/init.c b/lib/THRNN/init.c new file mode 100644 index 0000000..5853b83 --- /dev/null +++ b/lib/THRNN/init.c @@ -0,0 +1,4 @@ +#include "THRNN.h" + +#include "generic/LSTM.c" +#include "THGenerateFloatTypes.h" diff --git a/test/bigtest.lua b/test/bigtest.lua index 72fd913..3c27614 100644 --- a/test/bigtest.lua +++ b/test/bigtest.lua @@ -691,7 +691,9 @@ end function rnnbigtest.LSTM() local seqlen, batchsize = 30, 32 local inputsize, outputsize = 128, 128 - local nloop = 20 + local nloop = 100 + local ttype = torch.getdefaulttensortype() + torch.setdefaulttensortype('torch.FloatTensor') local lstms = {} lstms.fast = nn.Sequencer(nn.FastLSTM(inputsize, outputsize)) @@ -709,31 +711,45 @@ function rnnbigtest.LSTM() lstms.seq = nn.SeqLSTM(inputsize, outputsize) lstms.luaseq = nn.SeqLSTM(inputsize, outputsize) lstms.luaseq.forceLua = true + lstms.vseq = nn.VSeqLSTM(inputsize, outputsize) - local input = torch.Tensor(seqlen, batchsize, inputsize) - local gradOutput = torch.Tensor(seqlen, batchsize, outputsize) + local input = torch.Tensor(seqlen, batchsize, inputsize):uniform() + local gradOutput = torch.Tensor(seqlen, batchsize, outputsize):uniform() local t = torch.Timer() print("CPU test") for name, lstm in pairs(lstms) do + local linput,lgradOutput + if name == 'vseq' then + linput = {torch.Tensor(seqlen*batchsize,inputsize):uniform(), torch.LongTensor(seqlen):fill(batchsize)} + lgradOutput = torch.Tensor(seqlen*batchsize,outputsize):uniform() + else + linput = input + lgradOutput = gradOutput + end -- warmup 
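+      -- run the full forward/backward loop untimed first, so that lazy
+      -- allocations and buffer setup do not skew the timed measurements below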
lstm:remember('neither') - lstm:forward(input) - lstm:zeroGradParameters() - lstm:backward(input, gradOutput) + for i=1,nloop do + lstm:forward(linput) + lstm:zeroGradParameters() + lstm:backward(linput, lgradOutput) + end -- main test + collectgarbage() t:reset() + sys.tic() for i=1,nloop do - lstm:forward(input) + lstm:forward(linput) lstm:zeroGradParameters() - lstm:backward(input, gradOutput) + lstm:backward(linput, lgradOutput) end - lstm.testtime = t:time().real/nloop + lstm.testtime = sys.toc() / nloop -- t:time().real/nloop + collectgarbage() end - for i,name in ipairs{'fast','step','luarec','rec', 'luaseq', 'seq'} do + for i,name in ipairs{'fast','step','rec','luarec','luaseq','seq', 'vseq'} do print(name..' LSTM time: '..lstms[name].testtime..' seconds') end @@ -741,6 +757,7 @@ function rnnbigtest.LSTM() print("RecLSTM "..(lstms.fast.testtime/lstms.rec.testtime)..' faster than FastLSTM') print("SeqLSTM "..(lstms.rec.testtime/lstms.seq.testtime)..' faster than RecLSTM') print("SeqLSTM-C "..(lstms.luaseq.testtime/lstms.seq.testtime)..' faster than SeqLSTM-Lua') + print("VSeqLSTM "..(lstms.seq.testtime/lstms.vseq.testtime)..' faster than SeqLSTM-C') print("Memory test") @@ -750,9 +767,10 @@ function rnnbigtest.LSTM() lstm.clearsize = #torch.serialize(lstm) end - for i,name in ipairs{'fast','step','rec','seq'} do + for i,name in ipairs{'fast','step','rec','seq','vseq'} do print(name..' LSTM memory: '..lstms[name].fullsize/(1024*1024)..':'..lstms[name].clearsize/(1024*1024)..' MB') end + torch.setdefaulttensortype(ttype) end function rnnbigtest.GRU() @@ -818,6 +836,118 @@ function rnnbigtest.GRU() end end +function rnnbigtest.SeqLSTM_vs_VSeqLSTM() + local ttype = torch.getdefaulttensortype() + torch.setdefaulttensortype("torch.FloatTensor") + + local test = function(timesteps, bsize, embeddingSize, outputSize, ntests, vl_distribution) + -- The variable length distribution can be either + -- 'dense', 'diagonal' or 'worst-case' + local dense = vl_distribution == 'dense' + local sizes = torch.LongTensor(timesteps):fill(bsize) + if vl_distribution == 'diagonal' then + for i=1,timesteps do + sizes[i] = math.max(bsize - i + 1,1) + end + elseif vl_distribution == 'worst-case' then + sizes:fill(1) + sizes[1] = bsize + elseif not dense then + error('variable length distribution should either be "dense", "worst-case" or "diagonal"') + end + + local input_common = torch.Tensor(timesteps*512*embeddingSize):uniform() + local gradOutput_common = torch.Tensor(timesteps*512*outputSize):uniform() + + local input = torch.Tensor(timesteps, bsize, embeddingSize) + local mask = torch.ByteTensor(timesteps, bsize):zero() + local gradOutput = torch.Tensor(timesteps, bsize, outputSize) + + for i=1,bsize do + input[{{},{i,i},{}}]:copy(input_common[{{timesteps*(i-1)*embeddingSize+1,timesteps*i*embeddingSize}}]) + gradOutput[{{},{i,i},{}}]:copy(gradOutput_common[{{timesteps*(i-1)*outputSize+1,timesteps*i*outputSize}}]) + end + local input_flat = torch.Tensor(sizes:sum(), embeddingSize) + local gradOutput_flat = torch.Tensor(sizes:sum(), outputSize) + + local runningIdx = 1 + for i=1,timesteps do + input_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}]:copy(input[{{i,i},{1,sizes[i]},{}}]) + gradOutput_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}]:copy(gradOutput[{{i,i},{1,sizes[i]},{}}]) + if sizes[i] < bsize then + mask[{{i,i},{sizes[i]+1,bsize}}]:fill(1) + end + runningIdx = runningIdx + sizes[i] + end + local seqLSTM = nn.SeqLSTM(embeddingSize, outputSize) + if not dense then + seqLSTM:maskZero() + end + 
seqLSTM:remember('neither') + + local vseqLSTM = nn.VSeqLSTM(embeddingSize, outputSize) + vseqLSTM.weight = seqLSTM.weight + vseqLSTM.gradWeight = seqLSTM.gradWeight + vseqLSTM.gradBias = seqLSTM.gradBias + vseqLSTM.bias = seqLSTM.bias + + local output, gradInput + sys.tic() + for i=1,ntests do + if not dense then + seqLSTM:setZeroMask(mask) + end + output = seqLSTM:forward(input) + gradInput = seqLSTM:backward(input, gradOutput) + end + local seqLSTMDuration = sys.toc() + + local output_flat, gradInput_flat + local vinput = {input_flat, sizes} + sys.tic() + for i=1,ntests do + output_flat = vseqLSTM:forward(vinput) + gradInput_flat = vseqLSTM:backward(vinput, gradOutput_flat) + end + local vseqLSTMDuration = sys.toc() + + print('') + print('Results:') + print('\t- sequence length: ' .. timesteps) + print('\t- batch size: ' .. bsize) + print('\t- embeddingSize: ' .. embeddingSize) + print('\t- outputSize: ' .. outputSize) + print('\t- variable length distribution: ' .. vl_distribution) + print('\t- SPEEDUP: ' .. (seqLSTMDuration/vseqLSTMDuration)) + end + + local timesteps = 30 + local bsize = 30 + local embeddingSize = 128 + local outputSize = 128 + local ntests = 100 + test(timesteps, bsize, embeddingSize, outputSize, ntests, "dense") + test(timesteps, bsize, embeddingSize, outputSize, ntests, "diagonal") + test(timesteps, bsize, embeddingSize, outputSize, ntests, "worst-case") + + local timesteps = 16 + local bsize = 16 + local embeddingSize = 200 + local outputSize = 400 + local ntests = 100 + test(timesteps, bsize, embeddingSize, outputSize, ntests, "dense") + test(timesteps, bsize, embeddingSize, outputSize, ntests, "diagonal") + test(timesteps, bsize, embeddingSize, outputSize, ntests, "worst-case") + + local timesteps = 50 + local bsize = 1 + local embeddingSize = 128 + local outputSize = 128 + local ntests = 100 + test(timesteps, bsize, embeddingSize, outputSize, ntests, "dense") + + torch.setdefaulttensortype(ttype) +end function rnn.bigtest(tests) mytester = torch.Tester() diff --git a/test/test.lua b/test/test.lua index 660883e..6fc6f41 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2828,6 +2828,110 @@ function rnntest.SeqLSTM_Lua_vs_C() end end +function rnntest.SeqLSTM_vs_VSeqLSTM() + local ttype = torch.getdefaulttensortype() + torch.setdefaulttensortype("torch.DoubleTensor") + + local test = function(timesteps, bsize, embeddingSize, outputSize, ntests, vl_distribution) + -- The variable length distribution can be either + -- 'dense', 'diagonal' or 'worst-case' + local sizes = torch.LongTensor(timesteps):fill(bsize) + if vl_distribution == 'diagonal' then + for i=1,timesteps do + sizes[i] = math.max(bsize - i + 1,1) + end + elseif vl_distribution == 'worst-case' then + sizes:fill(1) + sizes[1] = bsize + elseif vl_distribution ~= 'dense' then + error('variable length distribution should either be "dense", "worst-case" or "diagonal"') + end + + local input_common = torch.Tensor(timesteps*512*embeddingSize):uniform() + local gradOutput_common = torch.Tensor(timesteps*512*outputSize):uniform() + + local input = torch.Tensor(timesteps, bsize, embeddingSize):zero() + local mask = torch.ByteTensor(timesteps, bsize):zero() + local gradOutput = torch.Tensor(timesteps, bsize, outputSize):zero() + + for i=1,bsize do + input[{{},{i,i},{}}]:copy(input_common[{{timesteps*(i-1)*embeddingSize+1,timesteps*i*embeddingSize}}]) + gradOutput[{{},{i,i},{}}]:copy(gradOutput_common[{{timesteps*(i-1)*outputSize+1,timesteps*i*outputSize}}]) + end + local input_flat = torch.Tensor(sizes:sum(), 
embeddingSize):zero() + local gradOutput_flat = torch.Tensor(sizes:sum(), outputSize):zero() + + local runningIdx = 1 + for i=1,timesteps do + input_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}]:copy(input[{{i,i},{1,sizes[i]},{}}]) + gradOutput_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}]:copy(gradOutput[{{i,i},{1,sizes[i]},{}}]) + if sizes[i] < bsize then + mask[{{i,i},{sizes[i]+1,bsize}}]:fill(1) + end + runningIdx = runningIdx + sizes[i] + end + local seqLSTM = nn.SeqLSTM(embeddingSize, outputSize) + seqLSTM:maskZero() + seqLSTM:remember('neither') + seqLSTM:reset() + + local vseqLSTM = nn.VSeqLSTM(embeddingSize, outputSize) + vseqLSTM.weight:copy(seqLSTM.weight) + vseqLSTM.gradWeight:copy(seqLSTM.gradWeight:zero()) + vseqLSTM.gradBias:copy(seqLSTM.gradBias:zero()) + vseqLSTM.bias:copy(seqLSTM.bias) + + local output, gradInput + sys.tic() + for i=1,ntests do + seqLSTM:setZeroMask(mask) + output = seqLSTM:forward(input) + gradInput = seqLSTM:backward(input, gradOutput) + end + local seqLSTMDuration = sys.toc() + + local output_flat, gradInput_flat + local vinput = {input_flat, sizes} + sys.tic() + for i=1,ntests do + output_flat = vseqLSTM:forward(vinput) + gradInput_flat = vseqLSTM:backward(vinput, gradOutput_flat) + end + local vseqLSTMDuration = sys.toc() + local runningIdx = 1 + for i=1,timesteps do + local lout = torch.view(output[{{i,i},{1,sizes[i]},{}}], output[{{i,i},{1,sizes[i]},{}}]:nElement()) + local lout_flat = torch.view(output_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}], output_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}]:nElement()) + mytester:assertTensorEq(lout, lout_flat, 0.00001) + runningIdx = runningIdx + sizes[i] + end + + runningIdx = 1 + for i=1,timesteps do + local lgi = torch.view(gradInput[{{i,i},{1,sizes[i]},{}}], gradInput[{{i,i},{1,sizes[i]},{}}]:nElement()) + local lgi_flat = torch.view(gradInput_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}], gradInput_flat[{{runningIdx,runningIdx+sizes[i]-1},{}}]:nElement()) + mytester:assertTensorEq(lgi, lgi_flat, 0.00001) + runningIdx = runningIdx + sizes[i] + end + local params, gradParams = seqLSTM:parameters() + local params2, gradParams2 = vseqLSTM:parameters() + for i=1,#params2 do + mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.00001) + end + end + + local timesteps = 30 + local bsize = 30 + local embeddingSize = 4 + local outputSize = 5 + local ntests = 1 + test(timesteps, bsize, embeddingSize, outputSize, ntests, "dense") + test(timesteps, bsize, embeddingSize, outputSize, ntests, "diagonal") + test(timesteps, bsize, embeddingSize, outputSize, ntests, "worst-case") + + torch.setdefaulttensortype(ttype) +end + function rnntest.SeqLSTM_maskzero() -- tests that it works with non-masked inputs regardless of maskzero's value. -- Note that more maskzero = true tests with masked inputs are in SeqLSTM unit test. 
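The tests above exercise the flattened "sparse/size" representation end-to-end. As a minimal sketch (the variable names below are illustrative only, not taken from the patch), the {flattened input, sizes} pair that nn.VSeqLSTM consumes can be built as follows, with sizes[t] giving the number of active sequences at timestep t (non-increasing across timesteps in these tests):

   local embeddingSize, outputSize = 4, 5
   -- three timesteps with 3, 2 and 1 active sequences respectively
   local sizes = torch.LongTensor{3, 2, 1}
   -- one row per (timestep, active sequence) pair, stored time-major
   local flat = torch.Tensor(sizes:sum(), embeddingSize):uniform()
   local vlstm = nn.VSeqLSTM(embeddingSize, outputSize)
   local out = vlstm:forward({flat, sizes})   -- sizes:sum() x outputSize

Because only the rows of active sequences are stored, shorter sequences do not pay for zero-padded timesteps, which is what the 'diagonal' and 'worst-case' speed comparisons above measure against SeqLSTM.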
@@ -3954,14 +4058,18 @@ function rnntest.VariableLength_lstm() -- VL(LSTM): test forward local input = {} - local lstm, vl, input2, output + local lstm, vlstm, vl, vvl, input2, output if not testLM then for i=1,batchSize do - input[i] = torch.randn(torch.random(1,maxLength), hiddenSize) + input[i] = torch.randn(i,hiddenSize)--torch.random(1,maxLength), hiddenSize) end lstm = nn.SeqLSTM(hiddenSize, hiddenSize):maskZero() + vlstm = nn.VSeqLSTM(hiddenSize, hiddenSize) + vlstm.weight:copy(lstm.weight) + vlstm.bias:copy(lstm.bias) + input2 = torch.Tensor(maxLength, batchSize, hiddenSize):zero() else for i=1,batchSize do @@ -3972,19 +4080,39 @@ function rnntest.VariableLength_lstm() :add(nn.LookupTableMaskZero(nIndex, hiddenSize)) :add(nn.SeqLSTM(hiddenSize, hiddenSize):maskZero()) + vlstm = nn.Sequential() + :add(nn.LookupTableMaskZero(nIndex, hiddenSize)) + :add(nn.VSeqLSTM(hiddenSize, hiddenSize)) + + local p = lstm:parameters() + local vp = vlstm:parameters() + for i=1,#vp do + vp[i]:copy(p[i]) + end + input2 = torch.Tensor(maxLength, batchSize):zero() end + vll = nn.VariableLength(vlstm:clone(), lastOnly, true) vl = nn.VariableLength(lstm:clone(), lastOnly) local output = vl:forward(input) + local voutput = vll:forward(input) for i=1,batchSize do local seqlen = input[i]:size(1) input2:select(2,i):narrow(1,maxLength-seqlen+1,seqlen):copy(input[i]) end - lstm:setZeroMask(nn.utils.getZeroMaskSequence(input2)) + + if not lastOnly then + for i=1,batchSize do + mytester:assertTensorEq(output[i], voutput[i], 0.000001) + end + else + mytester:assertTensorEq(output, voutput, 0.000001) + end + local output2 = lstm:forward(input2) if not lastOnly then @@ -4021,11 +4149,22 @@ function rnntest.VariableLength_lstm() vl:zeroGradParameters() local gradInput = vl:backward(input, gradOutput) + vll:zeroGradParameters() + local vgradInput = vll:backward(input, gradOutput) + + + for i=1,batchSize do + mytester:assertTensorEq(gradInput[i], vgradInput[i], 0.000001) + end for i=1,batchSize do mytester:assert(input[i]:isSameSizeAs(gradInput[i])) end + for i=1,batchSize do + mytester:assert(input[i]:isSameSizeAs(vgradInput[i])) + end + lstm:zeroGradParameters() local gradInput2 = lstm:backward(input2, gradOutput2)
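+   -- gradInput2 is the reference gradient: the masked SeqLSTM-based model
+   -- applied directly to the zero-padded input2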