Adding return state option to recurrent layers #2557

Open · wants to merge 3 commits into base: master
93 changes: 65 additions & 28 deletions src/layers/recurrent.jl
@@ -4,7 +4,7 @@ function scan(cell, x, state)
yt, state = cell(x_t, state)
y = vcat(y, [yt])
end
return stack(y, dims = 2)
return stack(y, dims = 2), state
end
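
A note on the new contract: `scan` now always returns the `(outputs, final_state)` pair, and each wrapper decides at its own level whether to expose the state. A minimal sketch, assuming the un-exported internal helpers are reached via the module (illustrative, not public API):

```julia
using Flux

cell = Flux.RNNCell(2 => 4)
x = rand(Float32, 2, 3)            # in x len
h0 = Flux.initialstates(cell)      # zero hidden state, size (4,)
y, h_end = Flux.scan(cell, x, h0)  # outputs *and* the final state
size(y)      # (4, 3): out x len
size(h_end)  # (4,)
```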

"""
@@ -58,16 +58,27 @@ julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
julia> y = rnn(x); # out x len x batch_size
```
"""
struct Recurrence{M}
struct Recurrence{S,M}
cell::M
end

@layer Recurrence

initialstates(rnn::Recurrence) = initialstates(rnn.cell)

function Recurrence(cell; return_state = false)
return Recurrence{return_state, typeof(cell)}(cell)
end

(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)

function (rnn::Recurrence{false})(x::AbstractArray, state)
first(scan(rnn.cell, x, state))
end

function (rnn::Recurrence{true})(x::AbstractArray, state)
scan(rnn.cell, x, state)
end
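
Lifting the keyword into the first type parameter fixes the return shape at construction time, so each call method stays type-stable. A usage sketch with an LSTM cell, whose carried state is an `(h, c)` pair:

```julia
using Flux

x = rand(Float32, 2, 3, 4)                         # in x len x batch_size
rec = Recurrence(LSTMCell(2 => 5); return_state = true)
y, (h, c) = rec(x)                                 # outputs plus the final (h, c)
size(y)  # (5, 3, 4): out x len x batch_size
size(h)  # (5, 4)
size(c)  # (5, 4)
```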

# Vanilla RNN
@doc raw"""
@@ -193,8 +204,8 @@ function Base.show(io::IO, m::RNNCell)
end

@doc raw"""
RNN(in => out, σ = tanh; init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
RNN(in => out, σ = tanh; return_state = false,
init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true)

The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
@@ -212,6 +223,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.

- `in => out`: The input and output dimensions of the layer.
- `σ`: The non-linearity to apply to the output. Default is `tanh`.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -260,26 +272,33 @@ Flux.@layer Model
model = Model(RNN(32 => 64), zeros(Float32, 64))
```
"""
struct RNN{M}
struct RNN{S,M}
cell::M
end

@layer :noexpand RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
return RNN(cell)
return RNN{return_state, typeof(cell)}(cell)
end

(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))

function (m::RNN)(x::AbstractArray, h)
function (rnn::RNN{false})(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
# [x] = [in, L] or [in, L, B]
# [h] = [out] or [out, B]
return scan(m.cell, x, h)
return first(scan(rnn.cell, x, h))
end

function (rnn::RNN{true})(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
# [x] = [in, L] or [in, L, B]
# [h] = [out] or [out, B]
return scan(rnn.cell, x, h)
end
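
For users the keyword is the whole API surface: the same layer returns either just the output sequence or, with `return_state = true`, the `(output, final_state)` pair. A quick sketch mirroring the new tests:

```julia
using Flux

rnn = RNN(2 => 4; return_state = true)
x = rand(Float32, 2, 3, 1)      # in x len x batch_size
y, h_end = rnn(x)
size(y)      # (4, 3, 1): out x len x batch_size
size(h_end)  # (4, 1):    out x batch_size
```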

function Base.show(io::IO, m::RNN)
@@ -391,7 +410,7 @@ Base.show(io::IO, m::LSTMCell) =


@doc raw"""
LSTM(in => out; init_kernel = glorot_uniform,
LSTM(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)

[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
@@ -415,6 +434,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.
# Arguments

- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -452,24 +472,29 @@ h = model(x)
size(h) # out x len x batch_size
```
"""
struct LSTM{M}
struct LSTM{S,M}
cell::M
end

@layer :noexpand LSTM

initialstates(lstm::LSTM) = initialstates(lstm.cell)

function LSTM((in, out)::Pair; cell_kwargs...)
function LSTM((in, out)::Pair; return_state = false, cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
return LSTM(cell)
return LSTM{return_state, typeof(cell)}(cell)
end

(lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))

function (m::LSTM)(x::AbstractArray, state0)
function (lstm::LSTM{false})(x::AbstractArray, state0)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(m.cell, x, state0)
return first(scan(lstm.cell, x, state0))
end

function (lstm::LSTM{true})(x::AbstractArray, state0)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(lstm.cell, x, state0)
end
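
The LSTM variant follows the same pattern, except the returned state is the `(h, c)` tuple, which can be destructured or fed straight back in as `state0` for a later call — a short sketch:

```julia
using Flux

lstm = LSTM(2 => 4; return_state = true)
x = rand(Float32, 2, 3, 1)
y, (h_end, c_end) = lstm(x)
size(h_end)  # (4, 1)
size(c_end)  # (4, 1)

# The returned tuple is a valid initial state for a subsequent call:
y2, state2 = lstm(x, (h_end, c_end))
```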

function Base.show(io::IO, m::LSTM)
@@ -578,7 +603,7 @@ Base.show(io::IO, m::GRUCell) =
print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")")

@doc raw"""
GRU(in => out; init_kernel = glorot_uniform,
GRU(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)

[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
@@ -599,6 +624,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step.
# Arguments

- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -625,24 +651,29 @@ h0 = zeros(Float32, d_out)
h = gru(x, h0) # out x len x batch_size
```
"""
struct GRU{M}
struct GRU{S,M}
cell::M
end

@layer :noexpand GRU

initialstates(gru::GRU) = initialstates(gru.cell)

function GRU((in, out)::Pair; cell_kwargs...)
function GRU((in, out)::Pair; return_state = false, cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
return GRU(cell)
return GRU{return_state, typeof(cell)}(cell)
end

(gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))

function (m::GRU)(x::AbstractArray, h)
function (gru::GRU{false})(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return first(scan(gru.cell, x, h))
end

function (gru::GRU{true})(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(m.cell, x, h)
return scan(gru.cell, x, h)
end
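
Exposing the final state lets a GRU process a long sequence in chunks, carrying the hidden state across segment boundaries — a sketch of stateful, truncated-BPTT-style use:

```julia
using Flux

gru = GRU(2 => 4; return_state = true)
h = Flux.initialstates(gru)      # zero state, size (4,)
x1, x2 = rand(Float32, 2, 10), rand(Float32, 2, 10)

y1, h = gru(x1, h)   # first segment
y2, h = gru(x2, h)   # second segment resumes where the first ended
```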

function Base.show(io::IO, m::GRU)
@@ -739,7 +770,7 @@ Base.show(io::IO, m::GRUv3Cell) =


@doc raw"""
GRUv3(in => out; init_kernel = glorot_uniform,
GRUv3(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)

[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
@@ -764,6 +795,7 @@ but only a less popular variant.
# Arguments

- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
@@ -790,24 +822,29 @@ h = gruv3(x, h0) # out x len x batch_size
h = gruv3(x, h0) # out x len x batch_size
```
"""
struct GRUv3{M}
struct GRUv3{S,M}
cell::M
end

@layer :noexpand GRUv3

initialstates(gru::GRUv3) = initialstates(gru.cell)

function GRUv3((in, out)::Pair; cell_kwargs...)
function GRUv3((in, out)::Pair; return_state = false, cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
return GRUv3(cell)
return GRUv3{return_state, typeof(cell)}(cell)
end

(gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))

function (m::GRUv3)(x::AbstractArray, h)
function (gru::GRUv3{false})(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return first(scan(gru.cell, x, h))
end

function (gru::GRUv3{true})(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(m.cell, x, h)
return scan(gru.cell, x, h)
end

function Base.show(io::IO, m::GRUv3)
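
All five wrappers (`Recurrence`, `RNN`, `LSTM`, `GRU`, `GRUv3`) repeat the same trick: the `Bool` keyword is lifted into a type parameter at construction, so keeping or dropping the state is resolved by dispatch rather than a runtime branch. A hypothetical, self-contained distillation of the pattern (not Flux API):

```julia
# The Bool keyword becomes a type parameter, so each call is type-stable
# and the output-vs-(output, state) choice costs nothing at runtime.
struct Wrapper{S,F}
    f::F
end
Wrapper(f; return_state::Bool = false) = Wrapper{return_state, typeof(f)}(f)

(w::Wrapper{false})(x) = first(w.f(x))  # discard the carried state
(w::Wrapper{true})(x) = w.f(x)          # return the (output, state) pair as-is

double = Wrapper(x -> (2x, x))                       # default: output only
both   = Wrapper(x -> (2x, x); return_state = true)  # output and state
double(3)  # 6
both(3)    # (6, 3)
```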
57 changes: 56 additions & 1 deletion test/layers/recurrent.jl
@@ -98,6 +98,16 @@ end
@test y isa Array{Float32, 2}
@test size(y) == (4, 3)
test_gradients(model, x)

# testing return state
model = ModelRNN(RNN(2 => 4; return_state = true), zeros(Float32, 4))
x = rand(Float32, 2, 3, 1)
y, last_state = model(x)
@test y isa Array{Float32, 3}
@test size(y) == (4, 3, 1)

@test last_state isa Matrix{Float32}
@test size(last_state) == (4, 1)
end

@testset "LSTMCell" begin
@@ -172,6 +182,18 @@ end
# no initial state same as zero initial state
h1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
@test h ≈ h1

# testing return state
model = ModelLSTM(LSTM(2 => 4; return_state = true), zeros(Float32, 4), zeros(Float32, 4))
x = rand(Float32, 2, 3, 1)
y, last_state = model(x)
@test y isa Array{Float32, 3}
@test size(y) == (4, 3, 1)

@test last_state[1] isa Matrix{Float32}
@test last_state[2] isa Matrix{Float32}
@test size(last_state[1]) == (4, 1)
@test size(last_state[2]) == (4, 1)
end

@testset "GRUCell" begin
@@ -236,6 +258,16 @@ end
gru = GRU(2 => 4, bias=false)
@test length(Flux.trainables(gru)) == 2
test_gradients(gru, x)

# testing return state
model = ModelGRU(GRU(2 => 4; return_state = true), zeros(Float32, 4))
x = rand(Float32, 2, 3, 1)
y, last_state = model(x)
@test y isa Array{Float32, 3}
@test size(y) == (4, 3, 1)

@test last_state isa Matrix{Float32}
@test size(last_state) == (4, 1)
end

@testset "GRUv3Cell" begin
@@ -289,13 +321,36 @@ end

# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))

# testing return state
model = ModelGRUv3(GRUv3(2 => 4; return_state = true), zeros(Float32, 4))
x = rand(Float32, 2, 3, 1)
y, last_state = model(x)
@test y isa Array{Float32, 3}
@test size(y) == (4, 3, 1)

@test last_state isa Matrix{Float32}
@test size(last_state) == (4, 1)
end

@testset "Recurrence" begin
x = rand(Float32, 2, 3, 4)
for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3), GRUv3(2 => 3)]
cell = rnn.cell
rec = Recurrence(cell)
@test rec(x) ≈ rnn(x)
end

for rnn in [RNN(2 => 3; return_state = true), LSTM(2 => 3; return_state = true),
GRU(2 => 3; return_state = true), GRUv3(2 => 3; return_state = true)]
cell = rnn.cell
rec = Recurrence(cell; return_state = true)
@test rec(x)[1] ≈ rnn(x)[1]
if !(typeof(rnn) <: LSTM)
@test rec(x)[2] ≈ rnn(x)[2]
else
@test rec(x)[2][1] ≈ rnn(x)[2][1]
@test rec(x)[2][2] ≈ rnn(x)[2][2]
end
end
end