Adding return_state to recurrent layers #39

Open · wants to merge 2 commits into base: main
22 changes: 16 additions & 6 deletions src/cells/fastrnn_cell.jl
@@ -86,7 +86,8 @@ Base.show(io::IO, fastrnn::FastRNNCell) =


@doc raw"""
FastRNN((input_size => hidden_size), [activation]; kwargs...)
FastRNN((input_size => hidden_size), [activation];
return_state = false, kwargs...)

[Fast recurrent neural network](https://arxiv.org/abs/1901.02358).
See [`FastRNNCell`](@ref) for a layer that processes a single sequence.
@@ -95,6 +96,7 @@ See [`FastRNNCell`](@ref) for a layer that processes a single sequence.

- `input_size => hidden_size`: input and inner dimension of the layer
- `activation`: the activation function, defaults to `tanh_fast`
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
@@ -122,17 +124,20 @@ h_t &= \alpha \tilde{h}_t + \beta h_{t-1}

## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
  When `return_state = true`, it returns a tuple of the hidden states `new_states` and
  the last state of the iteration.
"""
struct FastRNN{M} <: AbstractRecurrentLayer
struct FastRNN{S,M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :noexpand FastRNN

function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
return_state = false,
kwargs...)
cell = FastRNNCell(input_size => hidden_size, activation; kwargs...)
return FastRNN(cell)
return FastRNN{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, fastrnn::FastRNN)
@@ -235,7 +240,8 @@ Base.show(io::IO, fastgrnn::FastGRNNCell) =


@doc raw"""
FastGRNN((input_size => hidden_size), [activation]; kwargs...)
FastGRNN((input_size => hidden_size), [activation];
return_state = false, kwargs...)

[Fast gated recurrent neural network](https://arxiv.org/abs/1901.02358).
See [`FastGRNNCell`](@ref) for a layer that processes a single sequence.
@@ -244,6 +250,7 @@ See [`FastGRNNCell`](@ref) for a layer that processes a single sequence.

- `input_size => hidden_size`: input and inner dimension of the layer
- `activation`: the activation function, defaults to `tanh_fast`
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
@@ -273,17 +280,20 @@ h_t &= \big((\zeta (1 - z_t) + \nu) \odot \tilde{h}_t\big) + z_t \odot h_{t-1}

## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
  When `return_state = true`, it returns a tuple of the hidden states `new_states` and
  the last state of the iteration.
"""
struct FastGRNN{M} <: AbstractRecurrentLayer
struct FastGRNN{S,M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :noexpand FastGRNN

function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
return_state = false,
kwargs...)
cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...)
return FastGRNN(cell)
return FastGRNN{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, fastgrnn::FastGRNN)
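Taken together, the FastRNN and FastGRNN changes give the usage pattern below. A minimal sketch, assuming the layers are exported from the package module (here called `RecurrentLayers`) and take input of size `input_size x len x batch_size`; only the tuple return is stated by the docstrings above, the rest is assumption.

```julia
using RecurrentLayers  # assumed module name

inp = rand(Float32, 4, 10, 3)       # input_size x len x batch_size

# Default: return_state = false, identical behavior to before this PR.
fastrnn = FastRNN(4 => 8)
states = fastrnn(inp)               # hidden_size x len x batch_size
@assert size(states) == (8, 10, 3)

# Opt in: the layer additionally returns the state after the last step.
fastgrnn = FastGRNN(4 => 8; return_state = true)
states, last_state = fastgrnn(inp)
```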
15 changes: 10 additions & 5 deletions src/cells/indrnn_cell.jl
@@ -73,8 +73,8 @@ function Base.show(io::IO, indrnn::IndRNNCell)
end

@doc raw"""
IndRNN((input_size, hidden_size)::Pair, σ = tanh, σ=relu;
kwargs...)
IndRNN((input_size, hidden_size)::Pair, σ = tanh;
return_state = false, kwargs...)

[Independently recurrent network](https://arxiv.org/pdf/1803.04831).
See [`IndRNNCell`](@ref) for a layer that processes a single sequence.
@@ -83,6 +83,7 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence.

- `input_size => hidden_size`: input and inner dimension of the layer
- `σ`: activation function. Default is `tanh`
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
@@ -106,16 +107,20 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence.

## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
  When `return_state = true`, it returns a tuple of the hidden states `new_states` and
  the last state of the iteration.
"""
struct IndRNN{M} <: AbstractRecurrentLayer
struct IndRNN{S,M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :noexpand IndRNN

function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...)
function IndRNN((input_size, hidden_size)::Pair, σ = tanh;
return_state = false,
kwargs...)
cell = IndRNNCell(input_size => hidden_size, σ; kwargs...)
return IndRNN(cell)
return IndRNN{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, indrnn::IndRNN)
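Storing `return_state` as the leading type parameter `S` (rather than a field) lets the shared forward pass in `AbstractRecurrentLayer` branch at compile time. That forward pass is outside this diff; the following is a hypothetical sketch of such a dispatch, where the cell call signature and its return value are assumptions:

```julia
# Hypothetical sketch, not part of this diff. S is a compile-time Bool,
# so the branch at the end is resolved during specialization, at no runtime cost.
function (rnn::IndRNN{S})(inp::AbstractArray, state) where {S}
    new_state = state
    slices = map(eachslice(inp; dims=2)) do step
        new_state = rnn.cell(step, new_state)  # assumes the cell returns h_t
        new_state
    end
    states = stack(slices; dims=2)             # hidden_size x len x batch_size
    return S ? (states, new_state) : states
end
```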
14 changes: 10 additions & 4 deletions src/cells/lightru_cell.jl
@@ -80,14 +80,16 @@ Base.show(io::IO, lightru::LightRUCell) =


@doc raw"""
LightRU((input_size => hidden_size)::Pair; kwargs...)
LightRU((input_size => hidden_size);
return_state = false, kwargs...)

[Light recurrent unit network](https://www.mdpi.com/2079-9292/13/16/3204).
See [`LightRUCell`](@ref) for a layer that processes a single sequence.

# Arguments

- `input_size => hidden_size`: input and inner dimension of the layer
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
@@ -116,16 +118,20 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t.

## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
  When `return_state = true`, it returns a tuple of the hidden states `new_states` and
  the last state of the iteration.
"""
struct LightRU{M} <: AbstractRecurrentLayer
struct LightRU{S,M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :noexpand LightRU

function LightRU((input_size, hidden_size)::Pair; kwargs...)
function LightRU((input_size, hidden_size)::Pair;
return_state = false,
kwargs...)
cell = LightRUCell(input_size => hidden_size; kwargs...)
return LightRU(cell)
return LightRU{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, lightru::LightRU)
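One composition consequence worth flagging in review: with `return_state = true` the layer returns a tuple, which a plain `Flux.Chain` will not unpack for the next layer. A sketch under the same shape assumptions as above:

```julia
using Flux

# return_state = false composes directly:
model = Chain(LightRU(4 => 8), Dense(8 => 2))

# return_state = true requires explicit unpacking downstream:
lightru = LightRU(4 => 8; return_state = true)
states, last_state = lightru(rand(Float32, 4, 10, 3))
ŷ = Dense(8 => 2)(last_state)  # e.g. classify from the final state only
```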
14 changes: 10 additions & 4 deletions src/cells/ligru_cell.jl
@@ -81,7 +81,8 @@ Base.show(io::IO, ligru::LiGRUCell) =


@doc raw"""
LiGRU((input_size => hidden_size)::Pair; kwargs...)
LiGRU((input_size => hidden_size);
return_state = false, kwargs...)

[Light gated recurrent network](https://arxiv.org/pdf/1803.10225).
The implementation does not include the batch normalization as described in the original paper.
@@ -91,6 +92,7 @@ See [`LiGRUCell`](@ref) for a layer that processes a single sequence.
# Arguments

- `input_size => hidden_size`: input and inner dimension of the layer
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
@@ -119,16 +121,20 @@ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t

## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
  When `return_state = true`, it returns a tuple of the hidden states `new_states` and
  the last state of the iteration.
"""
struct LiGRU{M} <: AbstractRecurrentLayer
struct LiGRU{S,M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :noexpand LiGRU

function LiGRU((input_size, hidden_size)::Pair; kwargs...)
function LiGRU((input_size, hidden_size)::Pair;
return_state = false,
kwargs...)
cell = LiGRUCell(input_size => hidden_size; kwargs...)
return LiGRU(cell)
return LiGRU{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, ligru::LiGRU)
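For completeness, a sketch of the kind of training loop the new keyword enables, where the loss is computed from the final state only; everything here is illustrative and assumed, not part of the diff:

```julia
using Flux

ligru = LiGRU(4 => 8; return_state = true)
head = Dense(8 => 2)
opt_state = Flux.setup(Adam(1f-3), (ligru, head))

x = rand(Float32, 4, 10, 3)
y = Flux.onehotbatch(rand(1:2, 3), 1:2)

loss, grads = Flux.withgradient(ligru, head) do rnn, dense
    _, last_state = rnn(x)  # per-step states are discarded here
    Flux.logitcrossentropy(dense(last_state), y)
end
Flux.update!(opt_state, (ligru, head), grads)
```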
14 changes: 10 additions & 4 deletions src/cells/mgu_cell.jl
@@ -79,14 +79,16 @@ Base.show(io::IO, mgu::MGUCell) =


@doc raw"""
MGU((input_size => hidden_size)::Pair; kwargs...)
MGU((input_size => hidden_size);
return_state = false, kwargs...)

[Minimal gated unit network](https://arxiv.org/pdf/1603.09420).
See [`MGUCell`](@ref) for a layer that processes a single sequence.

# Arguments

- `input_size => hidden_size`: input and inner dimension of the layer
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
@@ -115,16 +117,20 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t

## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
  When `return_state = true`, it returns a tuple of the hidden states `new_states` and
  the last state of the iteration.
"""
struct MGU{M} <: AbstractRecurrentLayer
struct MGU{S,M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :noexpand MGU

function MGU((input_size, hidden_size)::Pair; kwargs...)
function MGU((input_size, hidden_size)::Pair;
return_state = false,
kwargs...)
cell = MGUCell(input_size => hidden_size; kwargs...)
return MGU(cell)
return MGU{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, mgu::MGU)
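The `{S,M}` pattern is applied uniformly across all touched layers, so a single parameterized test could pin down the PR's contract. A sketch, with the `last_state` shape being an assumption:

```julia
using Test

@testset "return_state across layers (sketch)" begin
    inp = rand(Float32, 3, 5, 2)  # input_size x len x batch_size
    for Layer in (FastRNN, FastGRNN, IndRNN, LightRU, LiGRU, MGU)
        rnn = Layer(3 => 4)
        @test size(rnn(inp)) == (4, 5, 2)
        rnn_rs = Layer(3 => 4; return_state = true)
        states, last_state = rnn_rs(inp)
        @test size(states) == (4, 5, 2)
        @test size(last_state) == (4, 2)  # assumed: hidden_size x batch_size
    end
end
```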