C implementation of LSTM #31

Open · wants to merge 1 commit into master
11 changes: 11 additions & 0 deletions CMakeLists.txt
@@ -5,12 +5,20 @@ CMAKE_POLICY(VERSION 2.6)

FIND_PACKAGE(Torch REQUIRED)

ADD_SUBDIRECTORY(lib)

SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE

SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}")
SET(src
init.c
)

FILE(STRINGS lib/THRNN/generic/THRNN.h THRNN_headers NEWLINE_CONSUME)
FILE(WRITE THRNN_h.lua "return [[")
FILE(APPEND THRNN_h.lua ${THRNN_headers})
FILE(APPEND THRNN_h.lua "]]")

SET(luasrc
init.lua
AbstractRecurrent.lua
@@ -36,6 +44,7 @@ SET(luasrc
SeqBLSTM.lua
SeqGRU.lua
SeqLSTM.lua
VSeqLSTM.lua
Sequencer.lua
SequencerCriterion.lua
test/bigtest.lua
@@ -75,6 +84,8 @@ SET(luasrc
deprecated/FastLSTM.lua
deprecated/GRU.lua
deprecated/LSTM.lua
THRNN.lua
THRNN_h.lua
)

ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
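The four FILE(...) commands above are a small build-time code-generation step: they read lib/THRNN/generic/THRNN.h, wrap its contents in a Lua long string, and write the result out as THRNN_h.lua, which THRNN.lua (below) loads to drive the FFI bindings. As a sketch, the generated file would have this shape (the declaration shown is a hypothetical placeholder; the real ones come from the PR's THRNN.h):

-- THRNN_h.lua (generated; the declaration below is a hypothetical example)
return [[
TH_API void THRNN_(LSTM_updateOutput)(
          THRNNState *state,
          ...);
]]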
6 changes: 3 additions & 3 deletions LookupTableMaskZero.lua
@@ -5,17 +5,17 @@ function LookupTableMaskZero:__init(nIndex, nOutput)
end

function LookupTableMaskZero:updateOutput(input)
   self.weight[1]:zero()
   if self.__input and (torch.type(self.__input) ~= torch.type(input)) then
      self.__input = nil -- fixes old casting bug
   end
   self.__input = self.__input or input.new()
   self.__input:resizeAs(input):add(input, 1)
   return parent.updateOutput(self, self.__input)
end

function LookupTableMaskZero:accGradParameters(input, gradOutput, scale)
   parent.accGradParameters(self, self.__input, gradOutput, scale)
end

function LookupTableMaskZero:type(type, cache)
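The whitespace-only changes above leave the module's masking trick intact: updateOutput zeroes weight[1] and shifts all input indices up by one, so index 0 always resolves to the zeroed row. A minimal usage sketch (assuming the rnn package is installed):

require 'rnn'
local lookup = nn.LookupTableMaskZero(10, 4) -- 10 indices, 4-dim embeddings
local input = torch.LongTensor{0, 3, 7}      -- 0 marks padding
local out = lookup:forward(input)
-- indices are shifted by one internally, so 0 selects weight[1], which is
-- zeroed on every forward pass; out[1] is therefore a zero vector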
9 changes: 9 additions & 0 deletions Module.lua
@@ -38,6 +38,15 @@ function Module:setZeroMask(zeroMask)
end
end

function Module:setContext(context)
   if self.modules then
      for i, module in ipairs(self.modules) do
         module:setContext(context)
      end
   end
   self.__context = context
end

function Module:stepClone(shareParams, shareGradParams)
return self:sharedClone(shareParams, shareGradParams, true)
end
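setContext mirrors setZeroMask just above it in Module.lua: the value is pushed recursively into self.modules and stored on each module as __context. The diff does not show what consumes the context (presumably VSeqLSTM), so the following is only a propagation sketch with a hypothetical context payload:

require 'rnn'
local seq = nn.Sequential()
   :add(nn.Linear(10, 10))
   :add(nn.Linear(10, 10))
seq:setContext({sizes = torch.LongTensor{5, 3}}) -- hypothetical context payload
-- every descendant now shares the same __context reference
assert(seq.modules[1].__context == seq.modules[2].__context)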
2 changes: 1 addition & 1 deletion SeqLSTM.lua
@@ -457,4 +457,4 @@ function SeqLSTM:toRecLSTM()
end

return lstm
end
147 changes: 147 additions & 0 deletions THRNN.lua
@@ -0,0 +1,147 @@
local ffi = require 'ffi'

local THRNN = {}


local generic_THRNN_h = require 'rnn.THRNN_h'
-- strip all lines starting with #
-- to remove preprocessor directives originally present
-- in THRNN.h
generic_THRNN_h = generic_THRNN_h:gsub("\n#[^\n]*", "")
generic_THRNN_h = generic_THRNN_h:gsub("^#[^\n]*\n", "")

-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h
local base_declarations = [[
typedef void THRNNState;

typedef struct {
  unsigned long the_initial_seed;
  int left;
  int seeded;
  unsigned long next;
  unsigned long state[624]; /* the array for the state vector 624 = _MERSENNE_STATE_N */
  double normal_x;
  double normal_y;
  double normal_rho;
  int normal_is_valid;
} THGenerator;
]]
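-- Presumably declared so that THRNN functions taking a THGenerator* argument
-- can be cdef'd below; the field layout mirrors TH's definition so the struct
-- is interpreted consistently on both sides of the FFI boundary.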

-- polyfill for Lua 5.1
if not package.searchpath then
   local sep = package.config:sub(1,1)
   function package.searchpath(mod, path)
      mod = mod:gsub('%.', sep)
      for m in path:gmatch('[^;]+') do
         local nm = m:gsub('?', mod)
         local f = io.open(nm, 'r')
         if f then
            f:close()
            return nm
         end
      end
   end
end

-- load libTHRNN
THRNN.C = ffi.load(package.searchpath('libTHRNN', package.cpath))

ffi.cdef(base_declarations)

-- expand macros, allowing the original lines from lib/THRNN/generic/THRNN.h to be used
local preprocessed = string.gsub(generic_THRNN_h, 'TH_API void THRNN_%(([%a%d_]+)%)', 'void THRNN_TYPE%1')

local replacements =
{
   {
      ['TYPE'] = 'Double',
      ['real'] = 'double',
      ['THTensor'] = 'THDoubleTensor',
      ['THIndexTensor'] = 'THLongTensor',
      ['THIntegerTensor'] = 'THIntTensor',
      ['THIndex_t'] = 'long',
      ['THInteger_t'] = 'int'
   },
   {
      ['TYPE'] = 'Float',
      ['real'] = 'float',
      ['THTensor'] = 'THFloatTensor',
      ['THIndexTensor'] = 'THLongTensor',
      ['THIntegerTensor'] = 'THIntTensor',
      ['THIndex_t'] = 'long',
      ['THInteger_t'] = 'int'
   }
}
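-- To illustrate the two-stage expansion with a hypothetical declaration:
--   TH_API void THRNN_(LSTM_updateOutput)(THRNNState *state, THTensor *input);
-- the gsub above first rewrites it to
--   void THRNN_TYPELSTM_updateOutput(THRNNState *state, THTensor *input);
-- and the Float replacement pass then yields
--   void THRNN_FloatLSTM_updateOutput(THRNNState *state, THFloatTensor *input);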

-- gsub(s, 'real', 'float') changes accreal to accfloat.
-- typedef accfloat ahead of time.
ffi.cdef("typedef double accfloat;")
-- gsub(s, 'real', 'double') changes accreal to accdouble.
-- typedef accdouble ahead of time.
ffi.cdef("typedef double accdouble;")

for i=1,#replacements do
   local r = replacements[i]
   local s = preprocessed
   for k,v in pairs(r) do
      s = string.gsub(s, k, v)
   end
   ffi.cdef(s)
end

THRNN.NULL = ffi.NULL or nil

function THRNN.getState()
   return ffi.NULL or nil
end

function THRNN.optionalTensor(t)
   return t and t:cdata() or THRNN.NULL
end

local function extract_function_names(s)
   local t = {}
   for n in string.gmatch(s, 'TH_API void THRNN_%(([%a%d_]+)%)') do
      t[#t+1] = n
   end
   return t
end

function THRNN.bind(lib, base_names, type_name, state_getter)
   local ftable = {}
   local prefix = 'THRNN_' .. type_name
   for i,n in ipairs(base_names) do
      -- use pcall since some libs might not support all functions (e.g. cunn)
      local ok,v = pcall(function() return lib[prefix .. n] end)
      if ok then
         ftable[n] = function(...) v(state_getter(), ...) end -- implicitly add state
      else
         print('not found: ' .. prefix .. n .. ': ' .. v)
      end
   end
   return ftable
end

-- build function table
local function_names = extract_function_names(generic_THRNN_h)

THRNN.kernels = {}
THRNN.kernels['torch.FloatTensor'] = THRNN.bind(THRNN.C, function_names, 'Float', THRNN.getState)
THRNN.kernels['torch.DoubleTensor'] = THRNN.bind(THRNN.C, function_names, 'Double', THRNN.getState)

torch.getmetatable('torch.FloatTensor').THRNN = THRNN.kernels['torch.FloatTensor']
torch.getmetatable('torch.DoubleTensor').THRNN = THRNN.kernels['torch.DoubleTensor']
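-- With the kernel tables attached to the metatables, module code can dispatch
-- on its tensor type directly, e.g. (hypothetical kernel name):
--   input.THRNN.LSTM_updateOutput(input:cdata(), ...)
-- where the bound closure inserts the THRNNState as the first argument.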

function THRNN.runKernel(f, type, ...)
   local ftable = THRNN.kernels[type]
   if not ftable then
      error('Unsupported tensor type: '..type)
   end
   local func = ftable[f] -- keep the name string f intact for the error message below
   if not func then
      error(string.format("Function '%s' not found for tensor type '%s'.", f, type))
   end
   func(...)
end
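-- Usage sketch (hypothetical kernel name): dispatch by tensor type string,
--   THRNN.runKernel('LSTM_updateOutput', torch.type(input), input:cdata(), ...)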

return THRNN