diff --git a/Project.toml b/Project.toml index 265687a..1de0954 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecurrentLayers" uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c" authors = ["Francesco Martinuzzi"] -version = "0.1.3" +version = "0.1.4" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index a527e60..7a8d67f 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -3,6 +3,36 @@ module RecurrentLayers using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform +import Flux: initialstates + +abstract type AbstractRecurrentCell end +abstract type AbstractDoubleRecurrentCell <: AbstractRecurrentCell end + +function initialstates(rcell::AbstractRecurrentCell) + return zeros_like(rcell.Wh, size(rcell.Wh, 2)) +end + +function initialstates(rcell::AbstractDoubleRecurrentCell) + state = zeros_like(rcell.Wh, size(rcell.Wh, 2)) + second_state = zeros_like(rcell.Wh, size(rcell.Wh, 2)) + return state, second_state +end + +function (rcell::AbstractRecurrentCell)(inp::AbstractVecOrMat) + state = initialstates(rcell) + return rcell(inp, state) +end + +abstract type AbstractRecurrentLayer end + +function initialstates(rlayer::AbstractRecurrentLayer) + return initialstates(rlayer.cell) +end + +function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) + state = initialstates(rlayer) + return rcell(inp, state) +end export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, diff --git a/src/fastrnn_cell.jl b/src/fastrnn_cell.jl index be40cb4..1272f04 100644 --- a/src/fastrnn_cell.jl +++ b/src/fastrnn_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/abs/1901.02358 -struct FastRNNCell{I, H, V, A, B, F} +struct FastRNNCell{I, H, V, A, B, F} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -53,11 +53,6 @@ function FastRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; return FastRNNCell(Wi, Wh, b, alpha, beta, activation) end -function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(fastrnn.Wh, 2)) - return fastrnn(inp, state) -end - function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat, state) #checks _size_check(fastrnn, inp, 1 => size(fastrnn.Wi,2)) @@ -77,7 +72,7 @@ Base.show(io::IO, fastrnn::FastRNNCell) = print(io, "FastRNNCell(", size(fastrnn.Wi, 2), " => ", size(fastrnn.Wi, 1) ÷ 2, ")") -struct FastRNN{M} +struct FastRNN{M} <: AbstractRecurrentLayer cell::M end @@ -114,11 +109,6 @@ function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast; cell = FastRNNCell(input_size => hidden_size, activation; kwargs...) return FastRNN(cell) end - -function (fastrnn::FastRNN)(inp) - state = zeros_like(inp, size(fastrnn.cell.Wh, 2)) - return fastrnn(inp, state) -end function (fastrnn::FastRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 @@ -131,7 +121,7 @@ function (fastrnn::FastRNN)(inp, state) end -struct FastGRNNCell{I, H, V, A, B, F} +struct FastGRNNCell{I, H, V, A, B, F} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -186,11 +176,6 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; return FastGRNNCell(Wi, Wh, b, alpha, beta, activation) end -function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(fastgrnn.Wh, 2)) - return fastgrnn(inp, state) -end - function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state) #checks _size_check(fastgrnn, inp, 1 => size(fastgrnn.Wi,2)) @@ -214,7 +199,7 @@ Base.show(io::IO, fastgrnn::FastGRNNCell) = print(io, "FastGRNNCell(", size(fastgrnn.Wi, 2), " => ", size(fastgrnn.Wi, 1) ÷ 2, ")") -struct FastGRNN{M} +struct FastGRNN{M} <: AbstractRecurrentLayer cell::M end @@ -252,11 +237,6 @@ function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast; cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...) return FastGRNN(cell) end - -function (fastgrnn::FastGRNN)(inp) - state = zeros_like(inp, size(fastgrnn.cell.Wh, 2)) - return fastgrnn(inp, state) -end function (fastgrnn::FastGRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index 00400eb..36f31f7 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -1,8 +1,8 @@ #https://arxiv.org/pdf/1803.04831 -struct IndRNNCell{F,I,H,V} +struct IndRNNCell{F,I,H,V} <: AbstractRecurrentCell σ::F Wi::I - u::H + Wh::H b::V end @@ -41,20 +41,15 @@ function IndRNNCell((input_size, hidden_size)::Pair, σ=relu; init_recurrent_kernel = glorot_uniform, bias = true) Wi = init_kernel(hidden_size, input_size) - u = init_recurrent_kernel(hidden_size) + Wh = init_recurrent_kernel(hidden_size) b = create_bias(Wi, bias, size(Wi, 1)) - return IndRNNCell(σ, Wi, u, b) -end - -function (indrnn::IndRNNCell)(x::AbstractVecOrMat) - state = zeros_like(x, size(indrnn.u, 1)) - return indrnn(x, state) + return IndRNNCell(σ, Wi, Wh, b) end function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) σ = NNlib.fast_act(indrnn.σ, inp) - state = σ.(indrnn.Wi*inp .+ indrnn.u .* state .+ indrnn.b) + state = σ.(indrnn.Wi*inp .+ indrnn.Wh .* state .+ indrnn.b) return state end @@ -64,7 +59,7 @@ function Base.show(io::IO, indrnn::IndRNNCell) print(io, ")") end -struct IndRNN{M} +struct IndRNN{M} <: AbstractRecurrentLayer cell::M end @@ -95,11 +90,6 @@ function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...) return IndRNN(cell) end -function (indrnn::IndRNN)(inp) - state = zeros_like(inp, size(indrnn.cell.u, 1)) - return indrnn(inp, state) -end - function (indrnn::IndRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/lightru_cell.jl b/src/lightru_cell.jl index 1f2f081..9824d8b 100644 --- a/src/lightru_cell.jl +++ b/src/lightru_cell.jl @@ -1,5 +1,5 @@ #https://www.mdpi.com/2079-9292/13/16/3204 -struct LightRUCell{I,H,V} +struct LightRUCell{I,H,V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -47,14 +47,9 @@ function LightRUCell((input_size, hidden_size)::Pair; return LightRUCell(Wi, Wh, b) end -function (lru::LightRUCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(lru.Wh, 2)) - return lru(inp, state) -end - -function (lru::LightRUCell)(inp::AbstractVecOrMat, state) - _size_check(lru, inp, 1 => size(lru.Wi,2)) - Wi, Wh, b = lru.Wi, lru.Wh, lru.bias +function (lightru::LightRUCell)(inp::AbstractVecOrMat, state) + _size_check(lightru, inp, 1 => size(lightru.Wi,2)) + Wi, Wh, b = lightru.Wi, lightru.Wh, lightru.bias #split gxs = chunk(Wi * inp, 2, dims=1) @@ -66,12 +61,12 @@ function (lru::LightRUCell)(inp::AbstractVecOrMat, state) return new_state end -Base.show(io::IO, lru::LightRUCell) = - print(io, "LightRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1)÷2, ")") +Base.show(io::IO, lightru::LightRUCell) = + print(io, "LightRUCell(", size(lightru.Wi, 2), " => ", size(lightru.Wi, 1)÷2, ")") -struct LightRU{M} +struct LightRU{M} <: AbstractRecurrentLayer cell::M end @@ -104,16 +99,11 @@ function LightRU((input_size, hidden_size)::Pair; kwargs...) return LightRU(cell) end -function (lru::LightRU)(inp) - state = zeros_like(inp, size(lru.cell.Wh, 2)) - return lru(inp, state) -end - -function (lru::LightRU)(inp, state) +function (lightru::LightRU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] for inp_t in eachslice(inp, dims=2) - state = lru.cell(inp_t, state) + state = lightru.cell(inp_t, state) new_state = vcat(new_state, [state]) end return stack(new_state, dims=2) diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index b09e842..88cd2a6 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1803.10225 -struct LiGRUCell{I, H, V} +struct LiGRUCell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -50,11 +50,6 @@ function LiGRUCell((input_size, hidden_size)::Pair; return LiGRUCell(Wi, Wh, b) end -function (ligru::LiGRUCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(ligru.Wh, 2)) - return ligru(inp, state) -end - function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) _size_check(ligru, inp, 1 => size(ligru.Wi,2)) Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.bias @@ -69,7 +64,7 @@ function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) end -struct LiGRU{M} +struct LiGRU{M} <: AbstractRecurrentLayer cell::M end @@ -104,11 +99,6 @@ function LiGRU((input_size, hidden_size)::Pair; kwargs...) return LiGRU(cell) end -function (ligru::LiGRU)(inp) - state = zeros_like(inp, size(ligru.cell.Wh, 2)) - return ligru(inp, state) -end - function (ligru::LiGRU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 3d81106..483ad7f 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1603.09420 -struct MGUCell{I, H, V} +struct MGUCell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -48,11 +48,6 @@ function MGUCell((input_size, hidden_size)::Pair; return MGUCell(Wi, Wh, b) end -function (mgu::MGUCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mgu.Wh, 2)) - return mgu(inp, state) -end - function (mgu::MGUCell)(inp::AbstractVecOrMat, state) _size_check(mgu, inp, 1 => size(mgu.Wi,2)) Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias @@ -70,7 +65,7 @@ Base.show(io::IO, mgu::MGUCell) = print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")") -struct MGU{M} +struct MGU{M} <: AbstractRecurrentLayer cell::M end @@ -102,11 +97,6 @@ function MGU((input_size, hidden_size)::Pair; kwargs...) cell = MGUCell(input_size => hidden_size; kwargs...) return MGU(cell) end - -function (mgu::MGU)(inp) - state = zeros_like(inp, size(mgu.cell.Wh, 2)) - return mgu(inp, state) -end function (mgu::MGU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 diff --git a/src/mut_cell.jl b/src/mut_cell.jl index 835ed3a..2d3a643 100644 --- a/src/mut_cell.jl +++ b/src/mut_cell.jl @@ -1,5 +1,5 @@ #https://proceedings.mlr.press/v37/jozefowicz15.pdf -struct MUT1Cell{I, H, V} +struct MUT1Cell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -49,11 +49,6 @@ function MUT1Cell((input_size, hidden_size)::Pair; return MUT1Cell(Wi, Wh, b) end -function (mut::MUT1Cell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mut.Wh, 2)) - return mut(inp, state) -end - function (mut::MUT1Cell)(inp::AbstractVecOrMat, state) _size_check(mut, inp, 1 => size(mut.Wi,2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias @@ -73,7 +68,7 @@ end Base.show(io::IO, mut::MUT1Cell) = print(io, "MUT1Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") -struct MUT1{M} +struct MUT1{M} <: AbstractRecurrentLayer cell::M end @@ -107,11 +102,6 @@ function MUT1((input_size, hidden_size)::Pair; kwargs...) return MUT1(cell) end -function (mut::MUT1)(inp) - state = zeros_like(inp, size(mut.cell.Wh, 2)) - return mut(inp, state) -end - function (mut::MUT1)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] @@ -123,8 +113,7 @@ function (mut::MUT1)(inp, state) end - -struct MUT2Cell{I, H, V} +struct MUT2Cell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -174,11 +163,6 @@ function MUT2Cell((input_size, hidden_size)::Pair; return MUT2Cell(Wi, Wh, b) end -function (mut::MUT2Cell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mut.Wh, 2)) - return mut(inp, state) -end - function (mut::MUT2Cell)(inp::AbstractVecOrMat, state) _size_check(mut, inp, 1 => size(mut.Wi,2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias @@ -198,7 +182,7 @@ Base.show(io::IO, mut::MUT2Cell) = print(io, "MUT2Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") -struct MUT2{M} +struct MUT2{M} <: AbstractRecurrentLayer cell::M end @@ -231,11 +215,6 @@ function MUT2((input_size, hidden_size)::Pair; kwargs...) cell = MUT2Cell(input_size => hidden_size; kwargs...) return MUT2(cell) end - -function (mut::MUT2)(inp) - state = zeros_like(inp, size(mut.cell.Wh, 2)) - return mut(inp, state) -end function (mut::MUT2)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 @@ -248,7 +227,7 @@ function (mut::MUT2)(inp, state) end -struct MUT3Cell{I, H, V} +struct MUT3Cell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -298,11 +277,6 @@ function MUT3Cell((input_size, hidden_size)::Pair; return MUT3Cell(Wi, Wh, b) end -function (mut::MUT3Cell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mut.Wh, 2)) - return mut(inp, state) -end - function (mut::MUT3Cell)(inp::AbstractVecOrMat, state) _size_check(mut, inp, 1 => size(mut.Wi,2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias @@ -320,7 +294,7 @@ end Base.show(io::IO, mut::MUT3Cell) = print(io, "MUT3Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") -struct MUT3{M} +struct MUT3{M} <: AbstractRecurrentLayer cell::M end @@ -353,11 +327,6 @@ function MUT3((input_size, hidden_size)::Pair; kwargs...) cell = MUT3Cell(input_size => hidden_size; kwargs...) return MUT3(cell) end - -function (mut::MUT3)(inp) - state = zeros_like(inp, size(mut.cell.Wh, 2)) - return mut(inp, state) -end function (mut::MUT3)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 diff --git a/src/nas_cell.jl b/src/nas_cell.jl index b5705fb..116b509 100644 --- a/src/nas_cell.jl +++ b/src/nas_cell.jl @@ -23,7 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -struct NASCell{I,H,V} +struct NASCell{I,H,V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -90,12 +90,6 @@ function NASCell((input_size, hidden_size)::Pair; return NASCell(Wi, Wh, b) end -function (nas::NASCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(nas.Wh, 2)) - c_state = zeros_like(state) - return nas(inp, (state, c_state)) -end - function (nas::NASCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(nas, inp, 1 => size(nas.Wi,2)) Wi, Wh, b = nas.Wi, nas.Wh, nas.bias @@ -136,7 +130,7 @@ Base.show(io::IO, nas::NASCell) = print(io, "NASCell(", size(nas.Wi, 2), " => ", size(nas.Wi, 1)÷8, ")") -struct NAS{M} +struct NAS{M} <: AbstractRecurrentLayer cell::M end @@ -190,12 +184,6 @@ function NAS((input_size, hidden_size)::Pair; kwargs...) return NAS(cell) end -function (nas::NAS)(inp) - state = zeros_like(inp, size(nas.cell.Wh, 2)) - c_state = zeros_like(state) - return nas(inp, (state, c_state)) -end - function (nas::NAS)(inp, (state, c_state)) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/peepholelstm_cell.jl b/src/peepholelstm_cell.jl index c675844..d7e18f8 100644 --- a/src/peepholelstm_cell.jl +++ b/src/peepholelstm_cell.jl @@ -1,5 +1,5 @@ #https://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf -struct PeepholeLSTMCell{I, H, V} +struct PeepholeLSTMCell{I, H, V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -62,12 +62,6 @@ function PeepholeLSTMCell( return cell end -function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(lstm.Wh, 2)) - c_state = zeros_like(state) - return lstm(inp, (state, c_state)) -end - function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(lstm, inp, 1 => size(lstm.Wi, 2)) @@ -84,7 +78,7 @@ Base.show(io::IO, lstm::PeepholeLSTMCell) = -struct PeepholeLSTM{M} +struct PeepholeLSTM{M} <: AbstractRecurrentLayer cell::M end @@ -120,12 +114,6 @@ function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...) return PeepholeLSTM(cell) end -function (lstm::PeepholeLSTM)(inp) - state = zeros_like(inp, size(lstm.cell.Wh, 2)) - c_state = zeros_like(state) - return lstm(inp, (state, c_state)) -end - function (lstm::PeepholeLSTM)(inp, (state, c_state)) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/ran_cell.jl b/src/ran_cell.jl index 6482cb5..aab0176 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1705.07393 -struct RANCell{I,H,V} +struct RANCell{I,H,V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -7,7 +7,6 @@ end Flux.@layer RANCell - @doc raw""" RANCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -76,12 +75,6 @@ function RANCell((input_size, hidden_size)::Pair; return RANCell(Wi, Wh, b) end -function (ran::RANCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(ran.Wh, 2)) - c_state = zeros_like(state) - return ran(inp, (state, c_state)) -end - function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(ran, inp, 1 => size(ran.Wi,2)) Wi, Wh, b = ran.Wi, ran.Wh, ran.bias @@ -102,7 +95,7 @@ Base.show(io::IO, ran::RANCell) = print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷3, ")") -struct RAN{M} +struct RAN{M} <: AbstractRecurrentLayer cell::M end @@ -142,12 +135,6 @@ function RAN((input_size, hidden_size)::Pair; kwargs...) return RAN(cell) end -function (ran::RAN)(inp) - state = zeros_like(inp, size(ran.cell.Wh, 2)) - c_state = zeros_like(state) - return ran(inp, (state, c_state)) -end - function (ran::RAN)(inp, (state, c_state)) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/scrn_cell.jl b/src/scrn_cell.jl index 3fa6be7..5a929a6 100644 --- a/src/scrn_cell.jl +++ b/src/scrn_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1412.7753 -struct SCRNCell{I,H,C,V,A} +struct SCRNCell{I,H,C,V,A} <: AbstractDoubleRecurrentCell Wi::I Wh::H Wc::C @@ -80,7 +80,7 @@ Base.show(io::IO, scrn::SCRNCell) = print(io, "SCRNCell(", size(scrn.Wi, 2), " => ", size(scrn.Wi, 1)÷3, ")") -struct SCRN{M} +struct SCRN{M} <: AbstractRecurrentLayer cell::M end diff --git a/src/sru_cell.jl b/src/sru_cell.jl index 57c65e6..1c17014 100644 --- a/src/sru_cell.jl +++ b/src/sru_cell.jl @@ -46,4 +46,4 @@ function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state)) end Base.show(io::IO, sru::SRUCell) = - print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1)÷2, ")") \ No newline at end of file + print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1)÷2, ")")