Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding return state option to recurrent layers #2557

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 113 additions & 32 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
yt, state = cell(x_t, state)
y = vcat(y, [yt])
end
return stack(y, dims = 2)
return stack(y, dims = 2), state
end

"""
Expand Down Expand Up @@ -58,16 +58,27 @@
julia> y = rnn(x); # out x len x batch_size
```
"""
# `Recurrence{S,M}` wraps a recurrent cell of type `M` and iterates it over the
# time dimension of the input. The Bool type parameter `S` records whether the
# layer also returns the final state (`return_state = true`), so the call
# methods can dispatch on it with no runtime branch.
struct Recurrence{S,M}
    cell::M
end

@layer Recurrence

# Initial state construction is delegated to the wrapped cell.
initialstates(rnn::Recurrence) = initialstates(rnn.cell)

# Keyword constructor: lifts `return_state` into the type domain.
function Recurrence(cell; return_state = false)
    return Recurrence{return_state, typeof(cell)}(cell)
end

# With no explicit state, start from the cell's default initial state.
(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))

# return_state = false: return only the stacked outputs from `scan`,
# discarding the final state.
function (rnn::Recurrence{false})(x::AbstractArray, state)
    return first(scan(rnn.cell, x, state))
end

# return_state = true: return `(outputs, final_state)` exactly as `scan` does.
function (rnn::Recurrence{true})(x::AbstractArray, state)
    return scan(rnn.cell, x, state)
end

# Vanilla RNN
@doc raw"""
Expand Down Expand Up @@ -193,8 +204,8 @@
end

@doc raw"""
RNN(in => out, σ = tanh; init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
RNN(in => out, σ = tanh; return_state = false,
init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true)

The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
Expand All @@ -212,6 +223,7 @@

- `in => out`: The input and output dimensions of the layer.
- `σ`: The non-linearity to apply to the output. Default is `tanh`.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -227,7 +239,8 @@
If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true`,
it returns a tuple of the hidden states `h_t` and the final state of the iteration.

# Examples

Expand Down Expand Up @@ -260,26 +273,43 @@
model = Model(RNN(32 => 64), zeros(Float32, 64))
```
"""
# `RNN{S,M}` wraps an `RNNCell` of type `M`; the Bool type parameter `S`
# encodes the `return_state` option so the call methods can dispatch on it.
struct RNN{S,M}
    cell::M
end

@layer :noexpand RNN

# Initial hidden state comes from the wrapped cell.
initialstates(rnn::RNN) = initialstates(rnn.cell)

# Build the cell from dimensions, then lift `return_state` into the type.
function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
    cell = RNNCell(in => out, σ; cell_kwargs...)
    return RNN{return_state, typeof(cell)}(cell)
end

# Wrap an existing cell directly.
function RNN(cell::RNNCell; return_state::Bool = false)
    return RNN{return_state, typeof(cell)}(cell)
end

# With no explicit state, start from the cell's default initial state.
(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))

# return_state = false: return only the stacked hidden states.
function (rnn::RNN{false})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    # [x] = [in, L] or [in, L, B]
    # [h] = [out] or [out, B]
    return first(scan(rnn.cell, x, h))
end

# return_state = true: return `(outputs, final_state)` from `scan`.
function (rnn::RNN{true})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    # [x] = [in, L] or [in, L, B]
    # [h] = [out] or [out, B]
    return scan(rnn.cell, x, h)
end

# Custom Functors destructuring: only `cell` is a child, and reconstruction
# must re-apply the `return_state` type parameter `S`, which is not stored
# in any field and would otherwise be lost.
function Functors.functor(rnn::RNN{S}) where {S}
    children = (cell = rnn.cell,)
    rebuild(nt) = RNN{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::RNN)
Expand Down Expand Up @@ -391,7 +421,7 @@


@doc raw"""
LSTM(in => out; init_kernel = glorot_uniform,
LSTM(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)

[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
Expand All @@ -415,6 +445,7 @@
# Arguments

- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -430,7 +461,8 @@
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).

Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`. When `return_state = true`,
it returns a tuple of the hidden states `h_t` and the final state of the iteration.

# Examples

Expand All @@ -452,24 +484,39 @@
size(h) # out x len x batch_size
```
"""
# `LSTM{S,M}` wraps an `LSTMCell` of type `M`; the Bool type parameter `S`
# encodes the `return_state` option so the call methods can dispatch on it.
struct LSTM{S,M}
    cell::M
end

@layer :noexpand LSTM

# Initial `(h, c)` state comes from the wrapped cell.
initialstates(lstm::LSTM) = initialstates(lstm.cell)

# Build the cell from dimensions, then lift `return_state` into the type.
function LSTM((in, out)::Pair; return_state = false, cell_kwargs...)
    cell = LSTMCell(in => out; cell_kwargs...)
    return LSTM{return_state, typeof(cell)}(cell)
end

# Wrap an existing cell directly.
function LSTM(cell::LSTMCell; return_state::Bool = false)
    return LSTM{return_state, typeof(cell)}(cell)
end

# With no explicit state, start from the cell's default initial state.
(lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))

# return_state = false: return only the stacked hidden states.
function (lstm::LSTM{false})(x::AbstractArray, state0)
    @assert ndims(x) == 2 || ndims(x) == 3
    return first(scan(lstm.cell, x, state0))
end

# return_state = true: return `(outputs, final_state)` from `scan`.
function (lstm::LSTM{true})(x::AbstractArray, state0)
    @assert ndims(x) == 2 || ndims(x) == 3
    return scan(lstm.cell, x, state0)
end

# Custom Functors destructuring: only `cell` is a child, and reconstruction
# must re-apply the `return_state` type parameter `S`, which is not stored
# in any field and would otherwise be lost.
function Functors.functor(lstm::LSTM{S}) where {S}
    children = (cell = lstm.cell,)
    rebuild(nt) = LSTM{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::LSTM)
Expand Down Expand Up @@ -578,7 +625,7 @@
print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")")

@doc raw"""
GRU(in => out; init_kernel = glorot_uniform,
GRU(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)

[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
Expand All @@ -599,6 +646,7 @@
# Arguments

- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -613,7 +661,8 @@
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true`,
it returns a tuple of the hidden states `h_t` and the final state of the iteration.

# Examples

Expand All @@ -625,24 +674,39 @@
h = gru(x, h0) # out x len x batch_size
```
"""
# `GRU{S,M}` wraps a `GRUCell` of type `M`; the Bool type parameter `S`
# encodes the `return_state` option so the call methods can dispatch on it.
struct GRU{S,M}
    cell::M
end

@layer :noexpand GRU

# Initial hidden state comes from the wrapped cell.
initialstates(gru::GRU) = initialstates(gru.cell)

# Build the cell from dimensions, then lift `return_state` into the type.
function GRU((in, out)::Pair; return_state = false, cell_kwargs...)
    cell = GRUCell(in => out; cell_kwargs...)
    return GRU{return_state, typeof(cell)}(cell)
end

# Wrap an existing cell directly.
function GRU(cell::GRUCell; return_state::Bool = false)
    return GRU{return_state, typeof(cell)}(cell)
end

# With no explicit state, start from the cell's default initial state.
(gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))

# return_state = false: return only the stacked hidden states.
function (gru::GRU{false})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    return first(scan(gru.cell, x, h))
end

# return_state = true: return `(outputs, final_state)` from `scan`.
function (gru::GRU{true})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    return scan(gru.cell, x, h)
end

# Custom Functors destructuring: only `cell` is a child, and reconstruction
# must re-apply the `return_state` type parameter `S`, which is not stored
# in any field and would otherwise be lost.
function Functors.functor(gru::GRU{S}) where {S}
    children = (cell = gru.cell,)
    rebuild(nt) = GRU{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::GRU)
Expand Down Expand Up @@ -739,7 +803,7 @@


@doc raw"""
GRUv3(in => out; init_kernel = glorot_uniform,
GRUv3(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)

[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
Expand All @@ -764,6 +828,7 @@
# Arguments

- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -778,7 +843,8 @@
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true`,
it returns a tuple of the hidden states `h_t` and the final state of the iteration.

# Examples

Expand All @@ -790,24 +856,39 @@
h = gruv3(x, h0) # out x len x batch_size
```
"""
# `GRUv3{S,M}` wraps a `GRUv3Cell` of type `M`; the Bool type parameter `S`
# encodes the `return_state` option so the call methods can dispatch on it.
struct GRUv3{S,M}
    cell::M
end

@layer :noexpand GRUv3

# Initial hidden state comes from the wrapped cell.
initialstates(gru::GRUv3) = initialstates(gru.cell)

# Build the cell from dimensions, then lift `return_state` into the type.
function GRUv3((in, out)::Pair; return_state = false, cell_kwargs...)
    cell = GRUv3Cell(in => out; cell_kwargs...)
    return GRUv3{return_state, typeof(cell)}(cell)
end

# Wrap an existing cell directly.
function GRUv3(cell::GRUv3Cell; return_state::Bool = false)
    return GRUv3{return_state, typeof(cell)}(cell)
end

# With no explicit state, start from the cell's default initial state.
(gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))

# return_state = false: return only the stacked hidden states.
function (gru::GRUv3{false})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    return first(scan(gru.cell, x, h))
end

# return_state = true: return `(outputs, final_state)` from `scan`.
function (gru::GRUv3{true})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    return scan(gru.cell, x, h)
end

# Custom Functors destructuring: only `cell` is a child, and reconstruction
# must re-apply the `return_state` type parameter `S`, which is not stored
# in any field and would otherwise be lost.
function Functors.functor(gru::GRUv3{S}) where {S}
    children = (cell = gru.cell,)
    rebuild(nt) = GRUv3{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::GRUv3)
Expand Down
Loading
Loading