diff --git a/Project.toml b/Project.toml index a4a5e9e..8faecbc 100644 --- a/Project.toml +++ b/Project.toml @@ -7,11 +7,13 @@ version = "0.2.4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [compat] Compat = "4.16.0" Flux = "0.16" Functors = "0.5.2" +NNlib = "0.9.27" julia = "1.10" [extras] diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index d637d37..ed791f4 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -2,39 +2,17 @@ module RecurrentLayers using Compat: @compat using Flux: _size_check, _match_eltype, chunk, create_bias, - zeros_like, glorot_uniform, scan, @layer, default_rng, tanh_fast + zeros_like, glorot_uniform, scan, @layer, + default_rng, Chain, Dropout import Flux: initialstates import Functors: functor - -rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1, - :MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN) - -rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell, - :MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell, - :RANCell, :SCRNCell) - -for (rlayer,rcell) in zip(rlayers, rcells) - @eval begin - function ($rlayer)(rc::$rcell; return_state::Bool = false) - return $rlayer{return_state, typeof(rc)}(rc) - end - - # why wont' this work? - #function functor(rl::$rlayer{S}) where {S} - # params = (cell = rl.cell) - # reconstruct = p -> $rlayer{S, typeof(p.cell)}(p.cell) - # return params, reconstruct - #end - end -end +using NNlib: fast_act, sigmoid_fast, tanh_fast, relu export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, -RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, -FastRNNCell, FastGRNNCell - + RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, + FastRNNCell, FastGRNNCell export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3, -SCRN, PeepholeLSTM, FastRNN, FastGRNN - + SCRN, PeepholeLSTM, FastRNN, FastGRNN export StackedRNN @compat(public, (initialstates)) @@ -55,4 +33,28 @@ include("cells/fastrnn_cell.jl") include("wrappers/stackedrnn.jl") + +### fallbacks for functors ### +rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1, + :MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN) + +rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell, + :MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell, + :RANCell, :SCRNCell) + +for (rlayer,rcell) in zip(rlayers, rcells) + @eval begin + function ($rlayer)(rc::$rcell; return_state::Bool = false) + return $rlayer{return_state, typeof(rc)}(rc) + end + + # why wont' this work? + #function functor(rl::$rlayer{S}) where {S} + # params = (cell = rl.cell) + # reconstruct = p -> $rlayer{S, typeof(p.cell)}(p.cell) + # return params, reconstruct + #end + end +end + end #module \ No newline at end of file diff --git a/src/cells/indrnn_cell.jl b/src/cells/indrnn_cell.jl index 56c867e..3142f9b 100644 --- a/src/cells/indrnn_cell.jl +++ b/src/cells/indrnn_cell.jl @@ -61,7 +61,7 @@ end function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) - σ = NNlib.fast_act(indrnn.σ, inp) + σ = fast_act(indrnn.σ, inp) state = σ.(indrnn.Wi*inp .+ indrnn.Wh .* state .+ indrnn.b) return state, state end diff --git a/src/cells/rhn_cell.jl b/src/cells/rhn_cell.jl index 39df802..002ba9b 100644 --- a/src/cells/rhn_cell.jl +++ b/src/cells/rhn_cell.jl @@ -125,9 +125,9 @@ function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat) pre_h, pre_t, pre_c = layer(inp_combined) # Apply nonlinearities - hidden_gate = tanh.(pre_h) - transform_gate = sigmoid.(pre_t) - carry_gate = sigmoid.(pre_c) + hidden_gate = tanh_fast.(pre_h) + transform_gate = sigmoid_fast.(pre_t) + carry_gate = sigmoid_fast.(pre_c) # Highway component if rhn.couple_carry @@ -167,29 +167,22 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell \end{aligned} ``` """ -struct RHN{M} +struct RHN{S,M} <: AbstractRecurrentLayer{S} cell::M end @layer :noexpand RHN -function RHN((input_size, hidden_size)::Pair, depth::Integer=3; kwargs...) +function RHN((input_size, hidden_size)::Pair, depth::Integer=3; + return_state::Bool = false, kwargs...) cell = RHNCell(input_size => hidden_size, depth; kwargs...) - return RHN(cell) + return RHN{return_state, typeof(cell)}(cell) end -function initialstates(rhn::RHN) - return initialstates(rhn.cell) -end - -function (rhn::RHN)(inp::AbstractArray) - state = initialstates(rhn) - return rhn(inp, state) -end - -function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(rhn.cell, inp, state) +function functor(rhn::RHN{S}) where {S} + params = (cell = rhn.cell,) + reconstruct = p -> RHN{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function colify(x::AbstractArray)