diff --git a/src/cells/fastrnn_cell.jl b/src/cells/fastrnn_cell.jl
index 3aa61f5..b99189b 100644
--- a/src/cells/fastrnn_cell.jl
+++ b/src/cells/fastrnn_cell.jl
@@ -86,7 +86,8 @@ Base.show(io::IO, fastrnn::FastRNNCell) =
 
 
 @doc raw"""
-    FastRNN((input_size => hidden_size), [activation]; kwargs...)
+    FastRNN((input_size => hidden_size), [activation];
+        return_state = false, kwargs...)
 
 [Fast recurrent neural network](https://arxiv.org/abs/1901.02358).
 See [`FastRNNCell`](@ref) for a layer that processes a single sequences.
@@ -95,6 +96,7 @@ See [`FastRNNCell`](@ref) for a layer that processes a single sequences.
 
 - `input_size => hidden_size`: input and inner dimension of the layer
 - `activation`: the activation function, defaults to `tanh_fast`
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: initializer for the input to hidden weights
 - `init_recurrent_kernel`: initializer for the hidden to hidden weights
 - `bias`: include a bias or not. Default is `true`
@@ -122,17 +124,20 @@ h_t &= \alpha \tilde{h}_t + \beta h_{t-1}
 ## Returns
 
 - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
+  When `return_state = true` it returns a tuple of the hidden states `new_states` and
+  the last state of the iteration.
 """
-struct FastRNN{M} <: AbstractRecurrentLayer
+struct FastRNN{S,M} <: AbstractRecurrentLayer
     cell::M
 end
 
 Flux.@layer :noexpand FastRNN
 
 function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
+        return_state = false,
         kwargs...)
     cell = FastRNNCell(input_size => hidden_size, activation; kwargs...)
-    return FastRNN(cell)
+    return FastRNN{return_state, typeof(cell)}(cell)
 end
 
 function Base.show(io::IO, fastrnn::FastRNN)
@@ -235,7 +240,8 @@ Base.show(io::IO, fastgrnn::FastGRNNCell) =
 
 
 @doc raw"""
-    FastGRNN((input_size => hidden_size), [activation]; kwargs...)
+    FastGRNN((input_size => hidden_size), [activation];
+        return_state = false, kwargs...)
 
 [Fast recurrent neural network](https://arxiv.org/abs/1901.02358).
 See [`FastGRNNCell`](@ref) for a layer that processes a single sequences.
@@ -244,6 +250,7 @@ See [`FastGRNNCell`](@ref) for a layer that processes a single sequences.
 
 - `input_size => hidden_size`: input and inner dimension of the layer
 - `activation`: the activation function, defaults to `tanh_fast`
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: initializer for the input to hidden weights
 - `init_recurrent_kernel`: initializer for the hidden to hidden weights
 - `bias`: include a bias or not. Default is `true`
@@ -273,17 +280,20 @@ h_t &= \big((\zeta (1 - z_t) + \nu) \odot \tilde{h}_t\big) + z_t \odot h_{t-1}
 ## Returns
 
 - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
+  When `return_state = true` it returns a tuple of the hidden states `new_states` and
+  the last state of the iteration.
 """
-struct FastGRNN{M} <: AbstractRecurrentLayer
+struct FastGRNN{S,M} <: AbstractRecurrentLayer
     cell::M
 end
 
 Flux.@layer :noexpand FastGRNN
 
 function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
+        return_state = false,
         kwargs...)
     cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...)
-    return FastGRNN(cell)
+    return FastGRNN{return_state, typeof(cell)}(cell)
 end
 
 function Base.show(io::IO, fastgrnn::FastGRNN)
diff --git a/src/cells/indrnn_cell.jl b/src/cells/indrnn_cell.jl
index 770f2f0..b0569a4 100644
--- a/src/cells/indrnn_cell.jl
+++ b/src/cells/indrnn_cell.jl
@@ -73,8 +73,8 @@ function Base.show(io::IO, indrnn::IndRNNCell)
 end
 
 @doc raw"""
-    IndRNN((input_size, hidden_size)::Pair, σ = tanh, σ=relu;
-        kwargs...)
+    IndRNN((input_size, hidden_size)::Pair, σ = tanh;
+        return_state = false, kwargs...)
 
 [Independently recurrent network](https://arxiv.org/pdf/1803.04831).
 See [`IndRNNCell`](@ref) for a layer that processes a single sequence.
@@ -83,6 +83,7 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence.
 
 - `input_size => hidden_size`: input and inner dimension of the layer
 - `σ`: activation function. Default is `tanh`
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: initializer for the input to hidden weights
 - `init_recurrent_kernel`: initializer for the hidden to hidden weights
 - `bias`: include a bias or not. Default is `true`
@@ -106,16 +107,20 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence.
 ## Returns
 
 - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
+  When `return_state = true` it returns a tuple of the hidden states `new_states` and
+  the last state of the iteration.
 """
-struct IndRNN{M} <: AbstractRecurrentLayer
+struct IndRNN{S,M} <: AbstractRecurrentLayer
     cell::M
 end
 
 Flux.@layer :noexpand IndRNN
 
-function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...)
+function IndRNN((input_size, hidden_size)::Pair, σ = tanh;
+        return_state = false,
+        kwargs...)
     cell = IndRNNCell(input_size => hidden_size, σ; kwargs...)
-    return IndRNN(cell)
+    return IndRNN{return_state, typeof(cell)}(cell)
 end
 
 function Base.show(io::IO, indrnn::IndRNN)
diff --git a/src/cells/lightru_cell.jl b/src/cells/lightru_cell.jl
index 669d0af..a1ba7d3 100644
--- a/src/cells/lightru_cell.jl
+++ b/src/cells/lightru_cell.jl
@@ -80,7 +80,8 @@ Base.show(io::IO, lightru::LightRUCell) =
 
 
 @doc raw"""
-    LightRU((input_size => hidden_size)::Pair; kwargs...)
+    LightRU((input_size => hidden_size);
+        return_state = false, kwargs...)
 
 [Light recurrent unit network](https://www.mdpi.com/2079-9292/13/16/3204).
 See [`LightRUCell`](@ref) for a layer that processes a single sequence.
@@ -88,6 +89,7 @@ See [`LightRUCell`](@ref) for a layer that processes a single sequence.
 # Arguments
 
 - `input_size => hidden_size`: input and inner dimension of the layer
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: initializer for the input to hidden weights
 - `init_recurrent_kernel`: initializer for the hidden to hidden weights
 - `bias`: include a bias or not. Default is `true`
@@ -116,16 +118,20 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t.
 ## Returns
 
 - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
+  When `return_state = true` it returns a tuple of the hidden states `new_states` and
+  the last state of the iteration.
 """
-struct LightRU{M} <: AbstractRecurrentLayer
+struct LightRU{S,M} <: AbstractRecurrentLayer
     cell::M
 end
 
 Flux.@layer :noexpand LightRU
 
-function LightRU((input_size, hidden_size)::Pair; kwargs...)
+function LightRU((input_size, hidden_size)::Pair;
+        return_state = false,
+        kwargs...)
     cell = LightRUCell(input_size => hidden_size; kwargs...)
-    return LightRU(cell)
+    return LightRU{return_state, typeof(cell)}(cell)
 end
 
 function Base.show(io::IO, lightru::LightRU)
diff --git a/src/cells/ligru_cell.jl b/src/cells/ligru_cell.jl
index 3caf969..4f8af5a 100644
--- a/src/cells/ligru_cell.jl
+++ b/src/cells/ligru_cell.jl
@@ -81,7 +81,8 @@ Base.show(io::IO, ligru::LiGRUCell) =
 
 
 @doc raw"""
-    LiGRU((input_size => hidden_size)::Pair; kwargs...)
+    LiGRU((input_size => hidden_size);
+        return_state = false, kwargs...)
 
 [Light gated recurrent network](https://arxiv.org/pdf/1803.10225).
 The implementation does not include the batch normalization as
@@ -91,6 +92,7 @@ See [`LiGRUCell`](@ref) for a layer that processes a single sequence.
 # Arguments
 
 - `input_size => hidden_size`: input and inner dimension of the layer
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: initializer for the input to hidden weights
 - `init_recurrent_kernel`: initializer for the hidden to hidden weights
 - `bias`: include a bias or not. Default is `true`
@@ -119,16 +121,20 @@ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
 ## Returns
 
 - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
+  When `return_state = true` it returns a tuple of the hidden states `new_states` and
+  the last state of the iteration.
 """
-struct LiGRU{M} <: AbstractRecurrentLayer
+struct LiGRU{S,M} <: AbstractRecurrentLayer
     cell::M
 end
 
 Flux.@layer :noexpand LiGRU
 
-function LiGRU((input_size, hidden_size)::Pair; kwargs...)
+function LiGRU((input_size, hidden_size)::Pair;
+        return_state = false,
+        kwargs...)
     cell = LiGRUCell(input_size => hidden_size; kwargs...)
-    return LiGRU(cell)
+    return LiGRU{return_state, typeof(cell)}(cell)
 end
 
 function Base.show(io::IO, ligru::LiGRU)
diff --git a/src/cells/mgu_cell.jl b/src/cells/mgu_cell.jl
index b7a693c..29c912c 100644
--- a/src/cells/mgu_cell.jl
+++ b/src/cells/mgu_cell.jl
@@ -79,7 +79,8 @@ Base.show(io::IO, mgu::MGUCell) =
 
 
 @doc raw"""
-    MGU((input_size => hidden_size)::Pair; kwargs...)
+    MGU((input_size => hidden_size);
+        return_state = false, kwargs...)
 
 [Minimal gated unit network](https://arxiv.org/pdf/1603.09420).
 See [`MGUCell`](@ref) for a layer that processes a single sequence.
@@ -87,6 +88,7 @@ See [`MGUCell`](@ref) for a layer that processes a single sequence.
 # Arguments
 
 - `input_size => hidden_size`: input and inner dimension of the layer
+- `return_state`: Option to return the last state together with the output. Default is `false`.
 - `init_kernel`: initializer for the input to hidden weights
 - `init_recurrent_kernel`: initializer for the hidden to hidden weights
 - `bias`: include a bias or not. Default is `true`
@@ -115,16 +117,20 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
 ## Returns
 
 - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
+  When `return_state = true` it returns a tuple of the hidden states `new_states` and
+  the last state of the iteration.
 """
-struct MGU{M} <: AbstractRecurrentLayer
+struct MGU{S,M} <: AbstractRecurrentLayer
     cell::M
 end
 
 Flux.@layer :noexpand MGU
 
-function MGU((input_size, hidden_size)::Pair; kwargs...)
+function MGU((input_size, hidden_size)::Pair;
+        return_state = false,
+        kwargs...)
     cell = MGUCell(input_size => hidden_size; kwargs...)
-    return MGU(cell)
+    return MGU{return_state, typeof(cell)}(cell)
 end
 
 function Base.show(io::IO, mgu::MGU)
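
For context, a brief usage sketch of the new `return_state` keyword, following the behaviour described by the docstrings added above. This is illustrative only: the module name `RecurrentLayers`, the 3-dimensional input layout, and calling the layer without an explicit initial state are assumptions, not code from this patch.

```julia
using Flux, RecurrentLayers

# Hypothetical example with FastRNN; the other layers follow the same pattern.
rnn_plain = FastRNN(3 => 5)                       # return_state = false by default
rnn_state = FastRNN(3 => 5; return_state = true)  # concrete type FastRNN{true, <:FastRNNCell}

x = rand(Float32, 3, 12, 8)   # input_size x len x batch_size

new_states = rnn_plain(x)                 # hidden states, hidden_size x len x batch_size
new_states, last_state = rnn_state(x)     # tuple: full state trajectory plus the last state
```

Storing `return_state` in the first type parameter (`FastRNN{S,M}`) rather than in a field lets the forward pass pick the return shape at compile time, for example by dispatching on `FastRNN{true}` versus `FastRNN{false}`, so the output type stays inferable.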