From 0aef30c0345c30a2ff56b182723e5381339abc44 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 12 Dec 2024 20:22:22 +0100 Subject: [PATCH 1/3] start of initialstates --- src/RecurrentLayers.jl | 1 + src/fastrnn_cell.jl | 16 ++++++++++++---- src/indrnn_cell.jl | 12 ++++++++---- src/lightru_cell.jl | 30 +++++++++++++++++------------- src/ligru_cell.jl | 8 ++++++-- src/mgu_cell.jl | 8 ++++++-- src/mut_cell.jl | 22 +++++++++++++++++----- src/sru_cell.jl | 2 +- 8 files changed, 68 insertions(+), 31 deletions(-) diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index a527e60..d131531 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -3,6 +3,7 @@ module RecurrentLayers using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform +import Flux: initialstates export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, diff --git a/src/fastrnn_cell.jl b/src/fastrnn_cell.jl index be40cb4..0e4acff 100644 --- a/src/fastrnn_cell.jl +++ b/src/fastrnn_cell.jl @@ -10,6 +10,8 @@ end Flux.@layer FastRNNCell +initialstates(fastrnn::FastRNNCell) = zeros_like(fastrnn.Wh, size(fastrnn.Wh, 2)) + @doc raw""" FastRNNCell((input_size => hidden_size), [activation]; init_kernel = glorot_uniform, @@ -54,7 +56,7 @@ function FastRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; end function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(fastrnn.Wh, 2)) + state = initialstates(fastrnn) return fastrnn(inp, state) end @@ -83,6 +85,8 @@ end Flux.@layer :expand FastRNN +initialstates(fastrnn::FastRNN) = initialstates(fastrnn.cell) + @doc raw""" FastRNN((input_size => hidden_size), [activation]; kwargs...) @@ -116,7 +120,7 @@ function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast; end function (fastrnn::FastRNN)(inp) - state = zeros_like(inp, size(fastrnn.cell.Wh, 2)) + state = initialstates(fastrnn) return fastrnn(inp, state) end @@ -142,6 +146,8 @@ end Flux.@layer FastGRNNCell +initialstates(fastgrnn::FastGRNN) = zeros_like(fastgrnn.Wh, size(fastgrnn.Wh, 2)) + @doc raw""" FastGRNNCell((input_size => hidden_size), [activation]; init_kernel = glorot_uniform, @@ -187,7 +193,7 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; end function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(fastgrnn.Wh, 2)) + state = initialstates(fastgrnn) return fastgrnn(inp, state) end @@ -220,6 +226,8 @@ end Flux.@layer :expand FastGRNN +initialstates(fastgrnn::FastGRNN) = initialstates(fastgrnn.cell) + @doc raw""" FastGRNN((input_size => hidden_size), [activation]; kwargs...) @@ -254,7 +262,7 @@ function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast; end function (fastgrnn::FastGRNN)(inp) - state = zeros_like(inp, size(fastgrnn.cell.Wh, 2)) + state = initialstates(fastgrnn) return fastgrnn(inp, state) end diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index 00400eb..d16492b 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -8,6 +8,8 @@ end Flux.@layer IndRNNCell +initialstates(indrnn::FastRNNCell) = zeros_like(indrnn.Wi, size(indrnn.Wi, 2)) + @doc raw""" IndRNNCell((input_size => hidden_size)::Pair, σ=relu; init_kernel = glorot_uniform, @@ -46,9 +48,9 @@ function IndRNNCell((input_size, hidden_size)::Pair, σ=relu; return IndRNNCell(σ, Wi, u, b) end -function (indrnn::IndRNNCell)(x::AbstractVecOrMat) - state = zeros_like(x, size(indrnn.u, 1)) - return indrnn(x, state) +function (indrnn::IndRNNCell)(inp::AbstractVecOrMat) + state = initialstates(indrnn) + return indrnn(inp, state) end function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) @@ -70,6 +72,8 @@ end Flux.@layer :expand IndRNN +initialstates(indrnn::IndRNN) = initialstates(indrnn.cell) + @doc raw""" IndRNN((input_size, hidden_size)::Pair, σ = tanh, σ=relu; kwargs...) @@ -96,7 +100,7 @@ function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...) end function (indrnn::IndRNN)(inp) - state = zeros_like(inp, size(indrnn.cell.u, 1)) + state = initialstates(indrnn) return indrnn(inp, state) end diff --git a/src/lightru_cell.jl b/src/lightru_cell.jl index 1f2f081..fc7b319 100644 --- a/src/lightru_cell.jl +++ b/src/lightru_cell.jl @@ -7,6 +7,8 @@ end Flux.@layer LightRUCell +initialstates(lightru::LightRUCell) = zeros_like(lightru.Wh, size(lightru.Wh, 2)) + @doc raw""" LightRUCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -47,14 +49,14 @@ function LightRUCell((input_size, hidden_size)::Pair; return LightRUCell(Wi, Wh, b) end -function (lru::LightRUCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(lru.Wh, 2)) - return lru(inp, state) +function (lightru::LightRUCell)(inp::AbstractVecOrMat) + state = initialstates(lightru) + return lightru(inp, state) end -function (lru::LightRUCell)(inp::AbstractVecOrMat, state) - _size_check(lru, inp, 1 => size(lru.Wi,2)) - Wi, Wh, b = lru.Wi, lru.Wh, lru.bias +function (lightru::LightRUCell)(inp::AbstractVecOrMat, state) + _size_check(lightru, inp, 1 => size(lightru.Wi,2)) + Wi, Wh, b = lightru.Wi, lightru.Wh, lightru.bias #split gxs = chunk(Wi * inp, 2, dims=1) @@ -66,8 +68,8 @@ function (lru::LightRUCell)(inp::AbstractVecOrMat, state) return new_state end -Base.show(io::IO, lru::LightRUCell) = - print(io, "LightRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1)÷2, ")") +Base.show(io::IO, lightru::LightRUCell) = + print(io, "LightRUCell(", size(lightru.Wi, 2), " => ", size(lightru.Wi, 1)÷2, ")") @@ -77,6 +79,8 @@ end Flux.@layer :expand LightRU +initialstates(lightru::LightRU) = initialstates(lightru.cell) + @doc raw""" LightRU((input_size => hidden_size)::Pair; kwargs...) @@ -104,16 +108,16 @@ function LightRU((input_size, hidden_size)::Pair; kwargs...) return LightRU(cell) end -function (lru::LightRU)(inp) - state = zeros_like(inp, size(lru.cell.Wh, 2)) - return lru(inp, state) +function (lightru::LightRU)(inp) + state = initialstates(lightru) + return lightru(inp, state) end -function (lru::LightRU)(inp, state) +function (lightru::LightRU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] for inp_t in eachslice(inp, dims=2) - state = lru.cell(inp_t, state) + state = lightru.cell(inp_t, state) new_state = vcat(new_state, [state]) end return stack(new_state, dims=2) diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index b09e842..ef5b780 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -7,6 +7,8 @@ end Flux.@layer LiGRUCell +initialstates(ligru::LiGRUCell) = zeros_like(ligru.Wh, size(ligru.Wh, 2)) + @doc raw""" LiGRUCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -51,7 +53,7 @@ function LiGRUCell((input_size, hidden_size)::Pair; end function (ligru::LiGRUCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(ligru.Wh, 2)) + state = initialstates(ligru) return ligru(inp, state) end @@ -75,6 +77,8 @@ end Flux.@layer :expand LiGRU +initialstates(ligru::LiGRU) = initialstates(ligru.cell) + @doc raw""" LiGRU((input_size => hidden_size)::Pair; kwargs...) @@ -105,7 +109,7 @@ function LiGRU((input_size, hidden_size)::Pair; kwargs...) end function (ligru::LiGRU)(inp) - state = zeros_like(inp, size(ligru.cell.Wh, 2)) + state = initialstates(ligru) return ligru(inp, state) end diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 3d81106..1463979 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -7,6 +7,8 @@ end Flux.@layer MGUCell +initialstates(mgu::MGUCell) = zeros_like(mgu.Wh, size(mgu.Wh, 2)) + @doc raw""" MGUCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -49,7 +51,7 @@ function MGUCell((input_size, hidden_size)::Pair; end function (mgu::MGUCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mgu.Wh, 2)) + state = initialstates(mgu) return mgu(inp, state) end @@ -76,6 +78,8 @@ end Flux.@layer :expand MGU +initialstates(mgu::MGU) = initialstates(mgu.cell) + @doc raw""" MGU((input_size => hidden_size)::Pair; kwargs...) @@ -104,7 +108,7 @@ function MGU((input_size, hidden_size)::Pair; kwargs...) end function (mgu::MGU)(inp) - state = zeros_like(inp, size(mgu.cell.Wh, 2)) + state = initialstates(mgu) return mgu(inp, state) end diff --git a/src/mut_cell.jl b/src/mut_cell.jl index 835ed3a..8c412c7 100644 --- a/src/mut_cell.jl +++ b/src/mut_cell.jl @@ -7,6 +7,8 @@ end Flux.@layer MUT1Cell +initialstates(mut::MUT1Cell) = zeros_like(mut.Wh, size(mut.Wh, 2)) + @doc raw""" MUT1Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -50,7 +52,7 @@ function MUT1Cell((input_size, hidden_size)::Pair; end function (mut::MUT1Cell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mut.Wh, 2)) + state = initialstates(mut) return mut(inp, state) end @@ -79,6 +81,8 @@ end Flux.@layer :expand MUT1 +initialstates(mut::MUT1) = initialstates(mut.cell) + @doc raw""" MUT1((input_size => hidden_size); kwargs...) @@ -108,7 +112,7 @@ function MUT1((input_size, hidden_size)::Pair; kwargs...) end function (mut::MUT1)(inp) - state = zeros_like(inp, size(mut.cell.Wh, 2)) + state = initialstates(mut) return mut(inp, state) end @@ -132,6 +136,8 @@ end Flux.@layer MUT2Cell +initialstates(mut::MUT2Cell) = zeros_like(mut.Wh, size(mut.Wh, 2)) + @doc raw""" MUT2Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -175,7 +181,7 @@ function MUT2Cell((input_size, hidden_size)::Pair; end function (mut::MUT2Cell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mut.Wh, 2)) + state = initialstates(mut) return mut(inp, state) end @@ -204,6 +210,8 @@ end Flux.@layer :expand MUT2 +initialstates(mut::MUT2) = initialstates(mut.cell) + @doc raw""" MUT2Cell((input_size => hidden_size); kwargs...) @@ -233,7 +241,7 @@ function MUT2((input_size, hidden_size)::Pair; kwargs...) end function (mut::MUT2)(inp) - state = zeros_like(inp, size(mut.cell.Wh, 2)) + state = initialstates(mut) return mut(inp, state) end @@ -256,6 +264,8 @@ end Flux.@layer MUT3Cell +initialstates(mut::MUT3Cell) = zeros_like(mut.Wh, size(mut.Wh, 2)) + @doc raw""" MUT3Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -299,7 +309,7 @@ function MUT3Cell((input_size, hidden_size)::Pair; end function (mut::MUT3Cell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(mut.Wh, 2)) + state = initialstates(mut) return mut(inp, state) end @@ -326,6 +336,8 @@ end Flux.@layer :expand MUT3 +initialstates(mut::MUT3) = initialstates(mut.cell) + @doc raw""" MUT3((input_size => hidden_size); kwargs...) diff --git a/src/sru_cell.jl b/src/sru_cell.jl index 57c65e6..1c17014 100644 --- a/src/sru_cell.jl +++ b/src/sru_cell.jl @@ -46,4 +46,4 @@ function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state)) end Base.show(io::IO, sru::SRUCell) = - print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1)÷2, ")") \ No newline at end of file + print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1)÷2, ")") From ce8bcf0d7f2ad05518a30bb282b3675dc9fa4453 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 13 Dec 2024 16:57:54 +0100 Subject: [PATCH 2/3] abstraction and initstates --- src/RecurrentLayers.jl | 29 +++++++++++++++++++++ src/fastrnn_cell.jl | 36 +++----------------------- src/indrnn_cell.jl | 26 +++++-------------- src/lightru_cell.jl | 18 ++----------- src/ligru_cell.jl | 18 ++----------- src/mgu_cell.jl | 18 ++----------- src/mut_cell.jl | 55 +++++----------------------------------- src/nas_cell.jl | 16 ++---------- src/peepholelstm_cell.jl | 16 ++---------- src/ran_cell.jl | 17 ++----------- src/scrn_cell.jl | 4 +-- 11 files changed, 59 insertions(+), 194 deletions(-) diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index d131531..6c0e76c 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -5,6 +5,35 @@ import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform import Flux: initialstates +abstract type AbstractRecurrentCell end +abstract type AbstractDoubleRecurrentCell <: AbstractRecurrentCell end + +function initialstates(rcell::AbstractRecurrentCell) + return zeros_like(rcell.Wh, size(rcell.Wh, 2)) +end + +function initialstates(rcell::AbstractDoubleRecurrentCell) + state = zeros_like(rcell.Wh, size(rcell.Wh, 2)) + second_state = zeros_like(rcell.Wh, size(rcell.Wh, 2)) + return state, second_state +end + +function (rcell::AbstractRecurrentCell)(inp) + state = initialstates(rcell) + return rcell(inp, state) +end + +abstract type AbstractRecurrentLayer end + +function initialstates(rlayer::AbstractRecurrentLayer) + return initialstates(rlayer.cell) +end + +function (rlayer::AbstractRecurrentLayer)(inp) + state = initialstates(rlayer) + return rcell(inp, state) +end + export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, FastRNNCell, FastGRNNCell diff --git a/src/fastrnn_cell.jl b/src/fastrnn_cell.jl index 0e4acff..1272f04 100644 --- a/src/fastrnn_cell.jl +++ b/src/fastrnn_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/abs/1901.02358 -struct FastRNNCell{I, H, V, A, B, F} +struct FastRNNCell{I, H, V, A, B, F} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -10,8 +10,6 @@ end Flux.@layer FastRNNCell -initialstates(fastrnn::FastRNNCell) = zeros_like(fastrnn.Wh, size(fastrnn.Wh, 2)) - @doc raw""" FastRNNCell((input_size => hidden_size), [activation]; init_kernel = glorot_uniform, @@ -55,11 +53,6 @@ function FastRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; return FastRNNCell(Wi, Wh, b, alpha, beta, activation) end -function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat) - state = initialstates(fastrnn) - return fastrnn(inp, state) -end - function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat, state) #checks _size_check(fastrnn, inp, 1 => size(fastrnn.Wi,2)) @@ -79,14 +72,12 @@ Base.show(io::IO, fastrnn::FastRNNCell) = print(io, "FastRNNCell(", size(fastrnn.Wi, 2), " => ", size(fastrnn.Wi, 1) ÷ 2, ")") -struct FastRNN{M} +struct FastRNN{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand FastRNN -initialstates(fastrnn::FastRNN) = initialstates(fastrnn.cell) - @doc raw""" FastRNN((input_size => hidden_size), [activation]; kwargs...) @@ -118,11 +109,6 @@ function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast; cell = FastRNNCell(input_size => hidden_size, activation; kwargs...) return FastRNN(cell) end - -function (fastrnn::FastRNN)(inp) - state = initialstates(fastrnn) - return fastrnn(inp, state) -end function (fastrnn::FastRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 @@ -135,7 +121,7 @@ function (fastrnn::FastRNN)(inp, state) end -struct FastGRNNCell{I, H, V, A, B, F} +struct FastGRNNCell{I, H, V, A, B, F} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -146,8 +132,6 @@ end Flux.@layer FastGRNNCell -initialstates(fastgrnn::FastGRNN) = zeros_like(fastgrnn.Wh, size(fastgrnn.Wh, 2)) - @doc raw""" FastGRNNCell((input_size => hidden_size), [activation]; init_kernel = glorot_uniform, @@ -192,11 +176,6 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; return FastGRNNCell(Wi, Wh, b, alpha, beta, activation) end -function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat) - state = initialstates(fastgrnn) - return fastgrnn(inp, state) -end - function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state) #checks _size_check(fastgrnn, inp, 1 => size(fastgrnn.Wi,2)) @@ -220,14 +199,12 @@ Base.show(io::IO, fastgrnn::FastGRNNCell) = print(io, "FastGRNNCell(", size(fastgrnn.Wi, 2), " => ", size(fastgrnn.Wi, 1) ÷ 2, ")") -struct FastGRNN{M} +struct FastGRNN{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand FastGRNN -initialstates(fastgrnn::FastGRNN) = initialstates(fastgrnn.cell) - @doc raw""" FastGRNN((input_size => hidden_size), [activation]; kwargs...) @@ -260,11 +237,6 @@ function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast; cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...) return FastGRNN(cell) end - -function (fastgrnn::FastGRNN)(inp) - state = initialstates(fastgrnn) - return fastgrnn(inp, state) -end function (fastgrnn::FastGRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index d16492b..36f31f7 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -1,15 +1,13 @@ #https://arxiv.org/pdf/1803.04831 -struct IndRNNCell{F,I,H,V} +struct IndRNNCell{F,I,H,V} <: AbstractRecurrentCell σ::F Wi::I - u::H + Wh::H b::V end Flux.@layer IndRNNCell -initialstates(indrnn::FastRNNCell) = zeros_like(indrnn.Wi, size(indrnn.Wi, 2)) - @doc raw""" IndRNNCell((input_size => hidden_size)::Pair, σ=relu; init_kernel = glorot_uniform, @@ -43,20 +41,15 @@ function IndRNNCell((input_size, hidden_size)::Pair, σ=relu; init_recurrent_kernel = glorot_uniform, bias = true) Wi = init_kernel(hidden_size, input_size) - u = init_recurrent_kernel(hidden_size) + Wh = init_recurrent_kernel(hidden_size) b = create_bias(Wi, bias, size(Wi, 1)) - return IndRNNCell(σ, Wi, u, b) -end - -function (indrnn::IndRNNCell)(inp::AbstractVecOrMat) - state = initialstates(indrnn) - return indrnn(inp, state) + return IndRNNCell(σ, Wi, Wh, b) end function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) σ = NNlib.fast_act(indrnn.σ, inp) - state = σ.(indrnn.Wi*inp .+ indrnn.u .* state .+ indrnn.b) + state = σ.(indrnn.Wi*inp .+ indrnn.Wh .* state .+ indrnn.b) return state end @@ -66,14 +59,12 @@ function Base.show(io::IO, indrnn::IndRNNCell) print(io, ")") end -struct IndRNN{M} +struct IndRNN{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand IndRNN -initialstates(indrnn::IndRNN) = initialstates(indrnn.cell) - @doc raw""" IndRNN((input_size, hidden_size)::Pair, σ = tanh, σ=relu; kwargs...) @@ -99,11 +90,6 @@ function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...) return IndRNN(cell) end -function (indrnn::IndRNN)(inp) - state = initialstates(indrnn) - return indrnn(inp, state) -end - function (indrnn::IndRNN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/lightru_cell.jl b/src/lightru_cell.jl index fc7b319..9824d8b 100644 --- a/src/lightru_cell.jl +++ b/src/lightru_cell.jl @@ -1,5 +1,5 @@ #https://www.mdpi.com/2079-9292/13/16/3204 -struct LightRUCell{I,H,V} +struct LightRUCell{I,H,V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -7,8 +7,6 @@ end Flux.@layer LightRUCell -initialstates(lightru::LightRUCell) = zeros_like(lightru.Wh, size(lightru.Wh, 2)) - @doc raw""" LightRUCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -49,11 +47,6 @@ function LightRUCell((input_size, hidden_size)::Pair; return LightRUCell(Wi, Wh, b) end -function (lightru::LightRUCell)(inp::AbstractVecOrMat) - state = initialstates(lightru) - return lightru(inp, state) -end - function (lightru::LightRUCell)(inp::AbstractVecOrMat, state) _size_check(lightru, inp, 1 => size(lightru.Wi,2)) Wi, Wh, b = lightru.Wi, lightru.Wh, lightru.bias @@ -73,14 +66,12 @@ Base.show(io::IO, lightru::LightRUCell) = -struct LightRU{M} +struct LightRU{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand LightRU -initialstates(lightru::LightRU) = initialstates(lightru.cell) - @doc raw""" LightRU((input_size => hidden_size)::Pair; kwargs...) @@ -108,11 +99,6 @@ function LightRU((input_size, hidden_size)::Pair; kwargs...) return LightRU(cell) end -function (lightru::LightRU)(inp) - state = initialstates(lightru) - return lightru(inp, state) -end - function (lightru::LightRU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index ef5b780..88cd2a6 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1803.10225 -struct LiGRUCell{I, H, V} +struct LiGRUCell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -7,8 +7,6 @@ end Flux.@layer LiGRUCell -initialstates(ligru::LiGRUCell) = zeros_like(ligru.Wh, size(ligru.Wh, 2)) - @doc raw""" LiGRUCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -52,11 +50,6 @@ function LiGRUCell((input_size, hidden_size)::Pair; return LiGRUCell(Wi, Wh, b) end -function (ligru::LiGRUCell)(inp::AbstractVecOrMat) - state = initialstates(ligru) - return ligru(inp, state) -end - function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) _size_check(ligru, inp, 1 => size(ligru.Wi,2)) Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.bias @@ -71,14 +64,12 @@ function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) end -struct LiGRU{M} +struct LiGRU{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand LiGRU -initialstates(ligru::LiGRU) = initialstates(ligru.cell) - @doc raw""" LiGRU((input_size => hidden_size)::Pair; kwargs...) @@ -108,11 +99,6 @@ function LiGRU((input_size, hidden_size)::Pair; kwargs...) return LiGRU(cell) end -function (ligru::LiGRU)(inp) - state = initialstates(ligru) - return ligru(inp, state) -end - function (ligru::LiGRU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 1463979..483ad7f 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1603.09420 -struct MGUCell{I, H, V} +struct MGUCell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -7,8 +7,6 @@ end Flux.@layer MGUCell -initialstates(mgu::MGUCell) = zeros_like(mgu.Wh, size(mgu.Wh, 2)) - @doc raw""" MGUCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -50,11 +48,6 @@ function MGUCell((input_size, hidden_size)::Pair; return MGUCell(Wi, Wh, b) end -function (mgu::MGUCell)(inp::AbstractVecOrMat) - state = initialstates(mgu) - return mgu(inp, state) -end - function (mgu::MGUCell)(inp::AbstractVecOrMat, state) _size_check(mgu, inp, 1 => size(mgu.Wi,2)) Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias @@ -72,14 +65,12 @@ Base.show(io::IO, mgu::MGUCell) = print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")") -struct MGU{M} +struct MGU{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand MGU -initialstates(mgu::MGU) = initialstates(mgu.cell) - @doc raw""" MGU((input_size => hidden_size)::Pair; kwargs...) @@ -106,11 +97,6 @@ function MGU((input_size, hidden_size)::Pair; kwargs...) cell = MGUCell(input_size => hidden_size; kwargs...) return MGU(cell) end - -function (mgu::MGU)(inp) - state = initialstates(mgu) - return mgu(inp, state) -end function (mgu::MGU)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 diff --git a/src/mut_cell.jl b/src/mut_cell.jl index 8c412c7..2d3a643 100644 --- a/src/mut_cell.jl +++ b/src/mut_cell.jl @@ -1,5 +1,5 @@ #https://proceedings.mlr.press/v37/jozefowicz15.pdf -struct MUT1Cell{I, H, V} +struct MUT1Cell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -7,8 +7,6 @@ end Flux.@layer MUT1Cell -initialstates(mut::MUT1Cell) = zeros_like(mut.Wh, size(mut.Wh, 2)) - @doc raw""" MUT1Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -51,11 +49,6 @@ function MUT1Cell((input_size, hidden_size)::Pair; return MUT1Cell(Wi, Wh, b) end -function (mut::MUT1Cell)(inp::AbstractVecOrMat) - state = initialstates(mut) - return mut(inp, state) -end - function (mut::MUT1Cell)(inp::AbstractVecOrMat, state) _size_check(mut, inp, 1 => size(mut.Wi,2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias @@ -75,14 +68,12 @@ end Base.show(io::IO, mut::MUT1Cell) = print(io, "MUT1Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") -struct MUT1{M} +struct MUT1{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand MUT1 -initialstates(mut::MUT1) = initialstates(mut.cell) - @doc raw""" MUT1((input_size => hidden_size); kwargs...) @@ -111,11 +102,6 @@ function MUT1((input_size, hidden_size)::Pair; kwargs...) return MUT1(cell) end -function (mut::MUT1)(inp) - state = initialstates(mut) - return mut(inp, state) -end - function (mut::MUT1)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] @@ -127,8 +113,7 @@ function (mut::MUT1)(inp, state) end - -struct MUT2Cell{I, H, V} +struct MUT2Cell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -136,8 +121,6 @@ end Flux.@layer MUT2Cell -initialstates(mut::MUT2Cell) = zeros_like(mut.Wh, size(mut.Wh, 2)) - @doc raw""" MUT2Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -180,11 +163,6 @@ function MUT2Cell((input_size, hidden_size)::Pair; return MUT2Cell(Wi, Wh, b) end -function (mut::MUT2Cell)(inp::AbstractVecOrMat) - state = initialstates(mut) - return mut(inp, state) -end - function (mut::MUT2Cell)(inp::AbstractVecOrMat, state) _size_check(mut, inp, 1 => size(mut.Wi,2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias @@ -204,14 +182,12 @@ Base.show(io::IO, mut::MUT2Cell) = print(io, "MUT2Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") -struct MUT2{M} +struct MUT2{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand MUT2 -initialstates(mut::MUT2) = initialstates(mut.cell) - @doc raw""" MUT2Cell((input_size => hidden_size); kwargs...) @@ -239,11 +215,6 @@ function MUT2((input_size, hidden_size)::Pair; kwargs...) cell = MUT2Cell(input_size => hidden_size; kwargs...) return MUT2(cell) end - -function (mut::MUT2)(inp) - state = initialstates(mut) - return mut(inp, state) -end function (mut::MUT2)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 @@ -256,7 +227,7 @@ function (mut::MUT2)(inp, state) end -struct MUT3Cell{I, H, V} +struct MUT3Cell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -264,8 +235,6 @@ end Flux.@layer MUT3Cell -initialstates(mut::MUT3Cell) = zeros_like(mut.Wh, size(mut.Wh, 2)) - @doc raw""" MUT3Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -308,11 +277,6 @@ function MUT3Cell((input_size, hidden_size)::Pair; return MUT3Cell(Wi, Wh, b) end -function (mut::MUT3Cell)(inp::AbstractVecOrMat) - state = initialstates(mut) - return mut(inp, state) -end - function (mut::MUT3Cell)(inp::AbstractVecOrMat, state) _size_check(mut, inp, 1 => size(mut.Wi,2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias @@ -330,14 +294,12 @@ end Base.show(io::IO, mut::MUT3Cell) = print(io, "MUT3Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") -struct MUT3{M} +struct MUT3{M} <: AbstractRecurrentLayer cell::M end Flux.@layer :expand MUT3 -initialstates(mut::MUT3) = initialstates(mut.cell) - @doc raw""" MUT3((input_size => hidden_size); kwargs...) @@ -365,11 +327,6 @@ function MUT3((input_size, hidden_size)::Pair; kwargs...) cell = MUT3Cell(input_size => hidden_size; kwargs...) return MUT3(cell) end - -function (mut::MUT3)(inp) - state = zeros_like(inp, size(mut.cell.Wh, 2)) - return mut(inp, state) -end function (mut::MUT3)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 diff --git a/src/nas_cell.jl b/src/nas_cell.jl index b5705fb..116b509 100644 --- a/src/nas_cell.jl +++ b/src/nas_cell.jl @@ -23,7 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -struct NASCell{I,H,V} +struct NASCell{I,H,V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -90,12 +90,6 @@ function NASCell((input_size, hidden_size)::Pair; return NASCell(Wi, Wh, b) end -function (nas::NASCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(nas.Wh, 2)) - c_state = zeros_like(state) - return nas(inp, (state, c_state)) -end - function (nas::NASCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(nas, inp, 1 => size(nas.Wi,2)) Wi, Wh, b = nas.Wi, nas.Wh, nas.bias @@ -136,7 +130,7 @@ Base.show(io::IO, nas::NASCell) = print(io, "NASCell(", size(nas.Wi, 2), " => ", size(nas.Wi, 1)÷8, ")") -struct NAS{M} +struct NAS{M} <: AbstractRecurrentLayer cell::M end @@ -190,12 +184,6 @@ function NAS((input_size, hidden_size)::Pair; kwargs...) return NAS(cell) end -function (nas::NAS)(inp) - state = zeros_like(inp, size(nas.cell.Wh, 2)) - c_state = zeros_like(state) - return nas(inp, (state, c_state)) -end - function (nas::NAS)(inp, (state, c_state)) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/peepholelstm_cell.jl b/src/peepholelstm_cell.jl index c675844..d7e18f8 100644 --- a/src/peepholelstm_cell.jl +++ b/src/peepholelstm_cell.jl @@ -1,5 +1,5 @@ #https://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf -struct PeepholeLSTMCell{I, H, V} +struct PeepholeLSTMCell{I, H, V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -62,12 +62,6 @@ function PeepholeLSTMCell( return cell end -function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(lstm.Wh, 2)) - c_state = zeros_like(state) - return lstm(inp, (state, c_state)) -end - function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(lstm, inp, 1 => size(lstm.Wi, 2)) @@ -84,7 +78,7 @@ Base.show(io::IO, lstm::PeepholeLSTMCell) = -struct PeepholeLSTM{M} +struct PeepholeLSTM{M} <: AbstractRecurrentLayer cell::M end @@ -120,12 +114,6 @@ function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...) return PeepholeLSTM(cell) end -function (lstm::PeepholeLSTM)(inp) - state = zeros_like(inp, size(lstm.cell.Wh, 2)) - c_state = zeros_like(state) - return lstm(inp, (state, c_state)) -end - function (lstm::PeepholeLSTM)(inp, (state, c_state)) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/ran_cell.jl b/src/ran_cell.jl index 6482cb5..aab0176 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1705.07393 -struct RANCell{I,H,V} +struct RANCell{I,H,V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -7,7 +7,6 @@ end Flux.@layer RANCell - @doc raw""" RANCell((input_size => hidden_size)::Pair; init_kernel = glorot_uniform, @@ -76,12 +75,6 @@ function RANCell((input_size, hidden_size)::Pair; return RANCell(Wi, Wh, b) end -function (ran::RANCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(ran.Wh, 2)) - c_state = zeros_like(state) - return ran(inp, (state, c_state)) -end - function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(ran, inp, 1 => size(ran.Wi,2)) Wi, Wh, b = ran.Wi, ran.Wh, ran.bias @@ -102,7 +95,7 @@ Base.show(io::IO, ran::RANCell) = print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷3, ")") -struct RAN{M} +struct RAN{M} <: AbstractRecurrentLayer cell::M end @@ -142,12 +135,6 @@ function RAN((input_size, hidden_size)::Pair; kwargs...) return RAN(cell) end -function (ran::RAN)(inp) - state = zeros_like(inp, size(ran.cell.Wh, 2)) - c_state = zeros_like(state) - return ran(inp, (state, c_state)) -end - function (ran::RAN)(inp, (state, c_state)) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] diff --git a/src/scrn_cell.jl b/src/scrn_cell.jl index 3fa6be7..5a929a6 100644 --- a/src/scrn_cell.jl +++ b/src/scrn_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1412.7753 -struct SCRNCell{I,H,C,V,A} +struct SCRNCell{I,H,C,V,A} <: AbstractDoubleRecurrentCell Wi::I Wh::H Wc::C @@ -80,7 +80,7 @@ Base.show(io::IO, scrn::SCRNCell) = print(io, "SCRNCell(", size(scrn.Wi, 2), " => ", size(scrn.Wi, 1)÷3, ")") -struct SCRN{M} +struct SCRN{M} <: AbstractRecurrentLayer cell::M end From 70550ba35e2bd03edf8037bc18a6d9c5425101bf Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 13 Dec 2024 17:04:41 +0100 Subject: [PATCH 3/3] typing --- Project.toml | 2 +- src/RecurrentLayers.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 265687a..1de0954 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecurrentLayers" uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c" authors = ["Francesco Martinuzzi"] -version = "0.1.3" +version = "0.1.4" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 6c0e76c..7a8d67f 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -18,7 +18,7 @@ function initialstates(rcell::AbstractDoubleRecurrentCell) return state, second_state end -function (rcell::AbstractRecurrentCell)(inp) +function (rcell::AbstractRecurrentCell)(inp::AbstractVecOrMat) state = initialstates(rcell) return rcell(inp, state) end @@ -29,7 +29,7 @@ function initialstates(rlayer::AbstractRecurrentLayer) return initialstates(rlayer.cell) end -function (rlayer::AbstractRecurrentLayer)(inp) +function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) state = initialstates(rlayer) return rcell(inp, state) end