diff --git a/Project.toml b/Project.toml index 1de0954..4874707 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecurrentLayers" uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c" authors = ["Francesco Martinuzzi"] -version = "0.1.4" +version = "0.1.5" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 7a8d67f..d0eb969 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -3,7 +3,7 @@ module RecurrentLayers using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform -import Flux: initialstates +import Flux: initialstates, scan abstract type AbstractRecurrentCell end abstract type AbstractDoubleRecurrentCell <: AbstractRecurrentCell end @@ -31,7 +31,7 @@ end function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) state = initialstates(rlayer) - return rcell(inp, state) + return rlayer(inp, state) end export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, diff --git a/src/fastrnn_cell.jl b/src/fastrnn_cell.jl index 1272f04..d414191 100644 --- a/src/fastrnn_cell.jl +++ b/src/fastrnn_cell.jl @@ -112,12 +112,7 @@ end function (fastrnn::FastRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = fastrnn.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(fastrnn.cell, inp, state) end @@ -173,7 +168,7 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; zeta = randn(Float32) nu = randn(Float32) - return FastGRNNCell(Wi, Wh, b, alpha, beta, activation) + return FastGRNNCell(Wi, Wh, b, zeta, nu, activation) end function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state) @@ -182,7 +177,7 @@ function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state) # get variables Wi, Wh, b = fastgrnn.Wi, fastgrnn.Wh, fastgrnn.bias - alpha, beta = fastgrnn.alpha, fastgrnn.beta + zeta, nu = fastgrnn.zeta, fastgrnn.nu bh, bz = chunk(b, 2) partial_gate = Wi * inp .+ Wh * state @@ -240,10 +235,5 @@ end function (fastgrnn::FastGRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = fastgrnn.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(fastgrnn.call, inp, state) end \ No newline at end of file diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index 36f31f7..c9a7228 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -92,10 +92,5 @@ end function (indrnn::IndRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = indrnn.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(indrnn.cell, inp, state) end \ No newline at end of file diff --git a/src/lightru_cell.jl b/src/lightru_cell.jl index 9824d8b..87362d4 100644 --- a/src/lightru_cell.jl +++ b/src/lightru_cell.jl @@ -101,10 +101,5 @@ end function (lightru::LightRU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = lightru.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(lightru.cell, inp, state) end diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index 88cd2a6..d622c14 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -101,12 +101,7 @@ end function (ligru::LiGRU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = ligru.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(ligru.cell, inp, state) end diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 483ad7f..b6ddd00 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -100,10 +100,5 @@ end function (mgu::MGU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = mgu.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(mgu.cell, inp, state) end diff --git a/src/mut_cell.jl b/src/mut_cell.jl index 2d3a643..dec5a8b 100644 --- a/src/mut_cell.jl +++ b/src/mut_cell.jl @@ -104,12 +104,7 @@ end function (mut::MUT1)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = mut.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(mut.cell, inp, state) end @@ -218,12 +213,7 @@ end function (mut::MUT2)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = mut.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(mut.cell, inp, state) end @@ -330,10 +320,5 @@ end function (mut::MUT3)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = mut.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(mut.cell, inp, state) end \ No newline at end of file diff --git a/src/nas_cell.jl b/src/nas_cell.jl index 116b509..c8ae7dc 100644 --- a/src/nas_cell.jl +++ b/src/nas_cell.jl @@ -184,14 +184,7 @@ function NAS((input_size, hidden_size)::Pair; kwargs...) return NAS(cell) end -function (nas::NAS)(inp, (state, c_state)) +function (nas::NAS)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - new_cstate = [] - for inp_t in eachslice(inp, dims=2) - state, c_state = nas.cell(inp_t, (state, c_state)) - new_state = vcat(new_state, [state]) - new_cstate = vcat(new_cstate, [c_state]) - end - return stack(new_state, dims=2), stack(new_cstate, dims=2) + return scan(nas.cell, inp, state) end diff --git a/src/peepholelstm_cell.jl b/src/peepholelstm_cell.jl index d7e18f8..34a2ba4 100644 --- a/src/peepholelstm_cell.jl +++ b/src/peepholelstm_cell.jl @@ -114,14 +114,7 @@ function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...) return PeepholeLSTM(cell) end -function (lstm::PeepholeLSTM)(inp, (state, c_state)) +function (lstm::PeepholeLSTM)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - new_cstate = [] - for inp_t in eachslice(inp, dims=2) - state, c_state = nas.cell(inp_t, (state, c_state)) - new_state = vcat(new_state, [state]) - new_cstate = vcat(new_cstate, [c_state]) - end - return stack(new_state, dims=2), stack(new_cstate, dims=2) + return scan(lstm.cell, inp, state) end diff --git a/src/ran_cell.jl b/src/ran_cell.jl index aab0176..f54dc1c 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -135,15 +135,8 @@ function RAN((input_size, hidden_size)::Pair; kwargs...) return RAN(cell) end -function (ran::RAN)(inp, (state, c_state)) +function (ran::RAN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - new_cstate = [] - for inp_t in eachslice(inp, dims=2) - state, c_state = ran.cell(inp_t, (state, c_state)) - new_state = vcat(new_state, [state]) - new_cstate = vcat(new_cstate, [c_state]) - end - return stack(new_state, dims=2), stack(new_cstate, dims=2) + return scan(ran.cell, inp, state) end diff --git a/src/scrn_cell.jl b/src/scrn_cell.jl index 5a929a6..bc099e3 100644 --- a/src/scrn_cell.jl +++ b/src/scrn_cell.jl @@ -117,18 +117,8 @@ function SCRN((input_size, hidden_size)::Pair; kwargs...) cell = SCRNCell(input_size => hidden_size; kwargs...) return SCRN(cell) end - -function (scrn::SCRN)(inp) - state = zeros_like(inp, size(scrn.cell.Wh, 2)) - return scrn(inp, state) -end function (scrn::SCRN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = scrn.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(scrn.cell, inp, state) end \ No newline at end of file