Skip to content

Commit

Permalink
C implementation of LSTM -- REVIEWABLE
Browse files Browse the repository at this point in the history
* New "sparse/size" representation
* Full LSTM in C
* VSeqLSTM to wrap this data representation + C implementation
* Augmentation of the VariableLength decorator with this data representation from an array of tensors
* unit tests
* speed tests
  • Loading branch information
nkoumchatzky committed Jun 3, 2017
1 parent fee152c commit 8f8677a
Show file tree
Hide file tree
Showing 16 changed files with 2,095 additions and 54 deletions.
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@ CMAKE_POLICY(VERSION 2.6)

FIND_PACKAGE(Torch REQUIRED)

ADD_SUBDIRECTORY(lib)

SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE

SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}")
SET(src
init.c
)

FILE(STRINGS lib/THRNN/generic/THRNN.h THRNN_headers NEWLINE_CONSUME)
FILE(WRITE THRNN_h.lua "return [[")
FILE(APPEND THRNN_h.lua ${THRNN_headers})
FILE(APPEND THRNN_h.lua "]]")

SET(luasrc
init.lua
AbstractRecurrent.lua
Expand All @@ -36,6 +44,7 @@ SET(luasrc
SeqBLSTM.lua
SeqGRU.lua
SeqLSTM.lua
VSeqLSTM.lua
Sequencer.lua
SequencerCriterion.lua
test/bigtest.lua
Expand Down Expand Up @@ -75,6 +84,8 @@ SET(luasrc
deprecated/FastLSTM.lua
deprecated/GRU.lua
deprecated/LSTM.lua
THRNN.lua
THRNN_h.lua
)

ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
Expand Down
6 changes: 3 additions & 3 deletions LookupTableMaskZero.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ function LookupTableMaskZero:__init(nIndex, nOutput)
end

--- Forward pass: shift input indices up by one so that index 0 maps to
-- weight row 1, which is kept zeroed (the "masked" embedding row).
-- @param input index tensor (0 denotes a masked entry)
-- @return embeddings from the parent LookupTable; masked entries yield zero vectors
function LookupTableMaskZero:updateOutput(input)
   -- row 1 is reserved for the zero mask; re-zero it on every forward in case
   -- a gradient update touched it (do this exactly once, not twice)
   self.weight[1]:zero()
   if self.__input and (torch.type(self.__input) ~= torch.type(input)) then
      self.__input = nil -- fixes old casting bug
   end
   self.__input = self.__input or input.new()
   -- shifted copy: index i becomes i + 1 so that 0 hits the zeroed row 1
   self.__input:resizeAs(input):add(input, 1)
   return parent.updateOutput(self, self.__input)
end

--- Backward pass: accumulate weight gradients using the cached shifted
-- indices (self.__input = input + 1) computed in updateOutput, so gradients
-- land on the same rows that produced the output.
function LookupTableMaskZero:accGradParameters(input, gradOutput, scale)
   -- accumulate exactly once; a duplicated call would double-count gradients
   parent.accGradParameters(self, self.__input, gradOutput, scale)
end

function LookupTableMaskZero:type(type, cache)
Expand Down
9 changes: 9 additions & 0 deletions Module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ function Module:setZeroMask(zeroMask)
end
end

--- Recursively attach a context value to this module and every child module.
-- Children in self.modules (when present) receive the same context via their
-- own setContext; the value is then stored on this module as __context.
-- @param context arbitrary value shared across the module tree
function Module:setContext(context)
   local children = self.modules
   if children then
      for _, child in ipairs(children) do
         child:setContext(context)
      end
   end
   self.__context = context
end

--- Create a per-time-step clone of this module.
-- Delegates to sharedClone with the two sharing flags forwarded unchanged and
-- the third argument fixed to true (marks the clone as a step clone;
-- sharedClone's exact semantics are defined elsewhere in the library).
function Module:stepClone(shareParams, shareGradParams)
return self:sharedClone(shareParams, shareGradParams, true)
end
Expand Down
2 changes: 1 addition & 1 deletion SeqLSTM.lua
Original file line number Diff line number Diff line change
Expand Up @@ -457,4 +457,4 @@ function SeqLSTM:toRecLSTM()
end

return lstm
end
end
147 changes: 147 additions & 0 deletions THRNN.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
local ffi = require 'ffi'

local THRNN = {}


-- Raw text of lib/THRNN/generic/THRNN.h, embedded as a Lua string at build
-- time (see the FILE(STRINGS ...) rules in CMakeLists.txt).
local generic_THRNN_h = require 'rnn.THRNN_h'
-- strip all lines starting with #
-- to remove preprocessor directives originally present
-- in THRNN.h
generic_THRNN_h = generic_THRNN_h:gsub("\n#[^\n]*", "")
generic_THRNN_h = generic_THRNN_h:gsub("^#[^\n]*\n", "")

-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h
-- NOTE(review): this must stay layout-compatible with the TH build in use;
-- a field mismatch would silently corrupt FFI reads -- verify on TH upgrades.
local base_declarations = [[
typedef void THRNNState;
typedef struct {
unsigned long the_initial_seed;
int left;
int seeded;
unsigned long next;
unsigned long state[624]; /* the array for the state vector 624 = _MERSENNE_STATE_N */
double normal_x;
double normal_y;
double normal_rho;
int normal_is_valid;
} THGenerator;
]]

-- Lua 5.1 compatibility: install package.searchpath when the runtime lacks it
-- (the function is native from Lua 5.2 onward).
if package.searchpath == nil then
   local dir_sep = package.config:sub(1, 1)
   -- Resolve a dotted module name against a ';'-separated template path,
   -- returning the first readable file name, or nothing when none match.
   package.searchpath = function(name, search_path)
      local rel = name:gsub('%.', dir_sep)
      for template in search_path:gmatch('[^;]+') do
         local candidate = template:gsub('?', rel)
         local handle = io.open(candidate, 'r')
         if handle ~= nil then
            handle:close()
            return candidate
         end
      end
   end
end

-- load libTHRNN (the compiled C backend) via the FFI
THRNN.C = ffi.load(package.searchpath('libTHRNN', package.cpath))

ffi.cdef(base_declarations)

-- expand macros, allow to use original lines from lib/THRNN/generic/THRNN.h
local preprocessed = string.gsub(generic_THRNN_h, 'TH_API void THRNN_%(([%a%d_]+)%)', 'void THRNN_TYPE%1')

-- Per-type token substitutions applied to the generic header before cdef.
-- NOTE(review): these are plain-text gsubs, so each key must not appear as a
-- substring of any other identifier in the header -- confirm when adding types.
local replacements =
{
{
['TYPE'] = 'Double',
['real'] = 'double',
['THTensor'] = 'THDoubleTensor',
['THIndexTensor'] = 'THLongTensor',
['THIntegerTensor'] = 'THIntTensor',
['THIndex_t'] = 'long',
['THInteger_t'] = 'int'
},
{
['TYPE'] = 'Float',
['real'] = 'float',
['THTensor'] = 'THFloatTensor',
['THIndexTensor'] = 'THLongTensor',
['THIntegerTensor'] = 'THIntTensor',
['THIndex_t'] = 'long',
['THInteger_t'] = 'int'
}
}

-- gsub(s, 'real', 'float') changes accreal to accfloat.
-- typedef accfloat ahead of time.
ffi.cdef("typedef double accfloat;")
-- gsub(s, 'real', 'double') changes accreal to accdouble.
-- typedef accdouble ahead of time.
ffi.cdef("typedef double accdouble;")

-- Declare every function once per scalar type.
for i=1,#replacements do
local r = replacements[i]
local s = preprocessed
for k,v in pairs(r) do
s = string.gsub(s, k, v)
end
ffi.cdef(s)
end

-- Canonical NULL value passed for optional/absent tensors.
THRNN.NULL = ffi.NULL or nil

--- Return the THRNNState handle implicitly passed as the first argument to
-- every bound C function; the state is currently NULL (stateless backend).
function THRNN.getState()
   return THRNN.NULL
end

--- Convert an optional tensor argument for the C API: the tensor's cdata
-- pointer when present, THRNN.NULL when t is nil/false.
function THRNN.optionalTensor(t)
   if t then
      return t:cdata()
   end
   return THRNN.NULL
end

-- Collect the base names of every TH_API function declared in header text s,
-- in declaration order.
local function extract_function_names(s)
   local names = {}
   for name in s:gmatch('TH_API void THRNN_%(([%a%d_]+)%)') do
      table.insert(names, name)
   end
   return names
end

--- Bind C functions from an FFI library handle into a Lua table.
-- @param lib ffi library handle (e.g. THRNN.C)
-- @param base_names array of base names; symbol looked up is THRNN_<type_name><name>
-- @param type_name scalar type tag, e.g. 'Float' or 'Double'
-- @param state_getter function returning the THRNNState passed as first C arg
-- @return table mapping base name -> wrapper that implicitly prepends the state
function THRNN.bind(lib, base_names, type_name, state_getter)
   local ftable = {}
   local prefix = 'THRNN_' .. type_name
   for i,n in ipairs(base_names) do
      -- use pcall since some libs might not support all functions (e.g. cunn)
      local ok,v = pcall(function() return lib[prefix .. n] end)
      if ok then
         ftable[n] = function(...) v(state_getter(), ...) end -- implicitly add state
      else
         -- on failure v holds the pcall error; separate it from the symbol
         -- name and tostring() it so a non-string error cannot break concat
         print('not found: ' .. prefix .. n .. ': ' .. tostring(v))
      end
   end
   return ftable
end

-- build function table: one list of base names extracted from the *generic*
-- header serves all scalar types
local function_names = extract_function_names(generic_THRNN_h)

THRNN.kernels = {}
THRNN.kernels['torch.FloatTensor'] = THRNN.bind(THRNN.C, function_names, 'Float', THRNN.getState)
THRNN.kernels['torch.DoubleTensor'] = THRNN.bind(THRNN.C, function_names, 'Double', THRNN.getState)

-- expose the bound kernels on the tensor metatables, so code can call
-- tensor.THRNN.<name>(...) for the matching scalar type
torch.getmetatable('torch.FloatTensor').THRNN = THRNN.kernels['torch.FloatTensor']
torch.getmetatable('torch.DoubleTensor').THRNN = THRNN.kernels['torch.DoubleTensor']

--- Look up kernel `f` for tensor type `type` and invoke it with the
-- remaining arguments.
-- @param f base function name, e.g. 'LSTM_updateOutput'
-- @param type torch tensor type string, e.g. 'torch.FloatTensor'
-- @raise when the tensor type or the function name is not registered
function THRNN.runKernel(f, type, ...)
   local ftable = THRNN.kernels[type]
   if not ftable then
      error('Unsupported tensor type: '..type)
   end
   -- do NOT reuse the name `f` for the lookup result: shadowing it made the
   -- error below format a nil instead of the requested function name
   local kernel = ftable[f]
   if not kernel then
      error(string.format("Function '%s' not found for tensor type '%s'.", f, type))
   end
   kernel(...)
end

return THRNN
Loading

0 comments on commit 8f8677a

Please sign in to comment.