Skip to content

Commit

Permalink
Merge pull request #23 from MartinuzziFrancesco/fm/is
Browse files Browse the repository at this point in the history
Adding initialstates
  • Loading branch information
MartinuzziFrancesco authored Dec 13, 2024
2 parents 32b1e9b + 70550ba commit 408a5c8
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 167 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecurrentLayers"
uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c"
authors = ["Francesco Martinuzzi"]
version = "0.1.3"
version = "0.1.4"

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
30 changes: 30 additions & 0 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,36 @@ module RecurrentLayers
using Flux
import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
import Flux: glorot_uniform
import Flux: initialstates

# Supertype for recurrent cells that carry a single hidden state;
# `initialstates` assumes subtypes expose a recurrent weight matrix in a `Wh` field.
abstract type AbstractRecurrentCell end
# Supertype for cells that carry two states (presumably hidden + cell state,
# e.g. LSTM-like cells — TODO confirm against the concrete subtypes).
abstract type AbstractDoubleRecurrentCell <: AbstractRecurrentCell end

"""
    initialstates(rcell::AbstractRecurrentCell)

Return a zero-initialized hidden state for `rcell`, with length equal to
the second dimension of the recurrent weight matrix `rcell.Wh` and the
same element type/device as `Wh` (via `zeros_like`).
"""
initialstates(rcell::AbstractRecurrentCell) = zeros_like(rcell.Wh, size(rcell.Wh, 2))

"""
    initialstates(rcell::AbstractDoubleRecurrentCell)

Return a tuple of two independent zero-initialized states (presumably a
hidden state and a second state such as an LSTM cell state — confirm
against the concrete subtypes), each sized from `size(rcell.Wh, 2)`.
"""
function initialstates(rcell::AbstractDoubleRecurrentCell)
    # Allocate each state separately so the two buffers do not alias.
    fresh_state() = zeros_like(rcell.Wh, size(rcell.Wh, 2))
    return fresh_state(), fresh_state()
end

"""
    (rcell::AbstractRecurrentCell)(inp::AbstractVecOrMat)

Run the cell on `inp` starting from a freshly zero-initialized state
obtained via `initialstates(rcell)`.
"""
function (rcell::AbstractRecurrentCell)(inp::AbstractVecOrMat)
    return rcell(inp, initialstates(rcell))
end

# Supertype for layer wrappers; `initialstates` assumes subtypes hold their
# recurrent cell in a `cell` field.
abstract type AbstractRecurrentLayer end

"""
    initialstates(rlayer::AbstractRecurrentLayer)

Return the initial state for `rlayer` by delegating to the wrapped
cell stored in `rlayer.cell`.
"""
initialstates(rlayer::AbstractRecurrentLayer) = initialstates(rlayer.cell)

"""
    (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)

Run the recurrent layer on `inp` starting from a freshly zero-initialized
state obtained via `initialstates(rlayer)`.
"""
function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)
    state = initialstates(rlayer)
    # Bug fix: the original called `rcell(inp, state)`, but `rcell` is not in
    # scope here (the parameter is `rlayer`), so any single-argument call on a
    # layer raised `UndefVarError(:rcell)`. Dispatch on the layer itself.
    return rlayer(inp, state)
end

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
Expand Down
28 changes: 4 additions & 24 deletions src/fastrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#https://arxiv.org/abs/1901.02358
struct FastRNNCell{I, H, V, A, B, F}
struct FastRNNCell{I, H, V, A, B, F} <: AbstractRecurrentCell
Wi::I
Wh::H
bias::V
Expand Down Expand Up @@ -53,11 +53,6 @@ function FastRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast;
return FastRNNCell(Wi, Wh, b, alpha, beta, activation)
end

function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(fastrnn.Wh, 2))
return fastrnn(inp, state)
end

function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat, state)
#checks
_size_check(fastrnn, inp, 1 => size(fastrnn.Wi,2))
Expand All @@ -77,7 +72,7 @@ Base.show(io::IO, fastrnn::FastRNNCell) =
print(io, "FastRNNCell(", size(fastrnn.Wi, 2), " => ", size(fastrnn.Wi, 1) ÷ 2, ")")


struct FastRNN{M}
struct FastRNN{M} <: AbstractRecurrentLayer
cell::M
end

Expand Down Expand Up @@ -114,11 +109,6 @@ function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
cell = FastRNNCell(input_size => hidden_size, activation; kwargs...)
return FastRNN(cell)
end

function (fastrnn::FastRNN)(inp)
state = zeros_like(inp, size(fastrnn.cell.Wh, 2))
return fastrnn(inp, state)
end

function (fastrnn::FastRNN)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
Expand All @@ -131,7 +121,7 @@ function (fastrnn::FastRNN)(inp, state)
end


struct FastGRNNCell{I, H, V, A, B, F}
struct FastGRNNCell{I, H, V, A, B, F} <: AbstractRecurrentCell
Wi::I
Wh::H
bias::V
Expand Down Expand Up @@ -186,11 +176,6 @@ function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast;
return FastGRNNCell(Wi, Wh, b, alpha, beta, activation)
end

function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(fastgrnn.Wh, 2))
return fastgrnn(inp, state)
end

function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state)
#checks
_size_check(fastgrnn, inp, 1 => size(fastgrnn.Wi,2))
Expand All @@ -214,7 +199,7 @@ Base.show(io::IO, fastgrnn::FastGRNNCell) =
print(io, "FastGRNNCell(", size(fastgrnn.Wi, 2), " => ", size(fastgrnn.Wi, 1) ÷ 2, ")")


struct FastGRNN{M}
struct FastGRNN{M} <: AbstractRecurrentLayer
cell::M
end

Expand Down Expand Up @@ -252,11 +237,6 @@ function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...)
return FastGRNN(cell)
end

function (fastgrnn::FastGRNN)(inp)
state = zeros_like(inp, size(fastgrnn.cell.Wh, 2))
return fastgrnn(inp, state)
end

function (fastgrnn::FastGRNN)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
Expand Down
22 changes: 6 additions & 16 deletions src/indrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#https://arxiv.org/pdf/1803.04831
struct IndRNNCell{F,I,H,V}
struct IndRNNCell{F,I,H,V} <: AbstractRecurrentCell
σ::F
Wi::I
u::H
Wh::H
b::V
end

Expand Down Expand Up @@ -41,20 +41,15 @@ function IndRNNCell((input_size, hidden_size)::Pair, σ=relu;
init_recurrent_kernel = glorot_uniform,
bias = true)
Wi = init_kernel(hidden_size, input_size)
u = init_recurrent_kernel(hidden_size)
Wh = init_recurrent_kernel(hidden_size)
b = create_bias(Wi, bias, size(Wi, 1))
return IndRNNCell(σ, Wi, u, b)
end

function (indrnn::IndRNNCell)(x::AbstractVecOrMat)
state = zeros_like(x, size(indrnn.u, 1))
return indrnn(x, state)
return IndRNNCell(σ, Wi, Wh, b)
end

function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
_size_check(indrnn, inp, 1 => size(indrnn.Wi, 2))
σ = NNlib.fast_act(indrnn.σ, inp)
state = σ.(indrnn.Wi*inp .+ indrnn.u .* state .+ indrnn.b)
state = σ.(indrnn.Wi*inp .+ indrnn.Wh .* state .+ indrnn.b)
return state
end

Expand All @@ -64,7 +59,7 @@ function Base.show(io::IO, indrnn::IndRNNCell)
print(io, ")")
end

struct IndRNN{M}
struct IndRNN{M} <: AbstractRecurrentLayer
cell::M
end

Expand Down Expand Up @@ -95,11 +90,6 @@ function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...)
return IndRNN(cell)
end

function (indrnn::IndRNN)(inp)
state = zeros_like(inp, size(indrnn.cell.u, 1))
return indrnn(inp, state)
end

function (indrnn::IndRNN)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
Expand Down
28 changes: 9 additions & 19 deletions src/lightru_cell.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#https://www.mdpi.com/2079-9292/13/16/3204
struct LightRUCell{I,H,V}
struct LightRUCell{I,H,V} <: AbstractRecurrentCell
Wi::I
Wh::H
bias::V
Expand Down Expand Up @@ -47,14 +47,9 @@ function LightRUCell((input_size, hidden_size)::Pair;
return LightRUCell(Wi, Wh, b)
end

function (lru::LightRUCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(lru.Wh, 2))
return lru(inp, state)
end

function (lru::LightRUCell)(inp::AbstractVecOrMat, state)
_size_check(lru, inp, 1 => size(lru.Wi,2))
Wi, Wh, b = lru.Wi, lru.Wh, lru.bias
function (lightru::LightRUCell)(inp::AbstractVecOrMat, state)
_size_check(lightru, inp, 1 => size(lightru.Wi,2))
Wi, Wh, b = lightru.Wi, lightru.Wh, lightru.bias

#split
gxs = chunk(Wi * inp, 2, dims=1)
Expand All @@ -66,12 +61,12 @@ function (lru::LightRUCell)(inp::AbstractVecOrMat, state)
return new_state
end

Base.show(io::IO, lru::LightRUCell) =
print(io, "LightRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1)÷2, ")")
Base.show(io::IO, lightru::LightRUCell) =
print(io, "LightRUCell(", size(lightru.Wi, 2), " => ", size(lightru.Wi, 1)÷2, ")")



struct LightRU{M}
struct LightRU{M} <: AbstractRecurrentLayer
cell::M
end

Expand Down Expand Up @@ -104,16 +99,11 @@ function LightRU((input_size, hidden_size)::Pair; kwargs...)
return LightRU(cell)
end

function (lru::LightRU)(inp)
state = zeros_like(inp, size(lru.cell.Wh, 2))
return lru(inp, state)
end

function (lru::LightRU)(inp, state)
function (lightru::LightRU)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
for inp_t in eachslice(inp, dims=2)
state = lru.cell(inp_t, state)
state = lightru.cell(inp_t, state)
new_state = vcat(new_state, [state])
end
return stack(new_state, dims=2)
Expand Down
14 changes: 2 additions & 12 deletions src/ligru_cell.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#https://arxiv.org/pdf/1803.10225
struct LiGRUCell{I, H, V}
struct LiGRUCell{I, H, V} <: AbstractRecurrentCell
Wi::I
Wh::H
bias::V
Expand Down Expand Up @@ -50,11 +50,6 @@ function LiGRUCell((input_size, hidden_size)::Pair;
return LiGRUCell(Wi, Wh, b)
end

function (ligru::LiGRUCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(ligru.Wh, 2))
return ligru(inp, state)
end

function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state)
_size_check(ligru, inp, 1 => size(ligru.Wi,2))
Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.bias
Expand All @@ -69,7 +64,7 @@ function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state)
end


struct LiGRU{M}
struct LiGRU{M} <: AbstractRecurrentLayer
cell::M
end

Expand Down Expand Up @@ -104,11 +99,6 @@ function LiGRU((input_size, hidden_size)::Pair; kwargs...)
return LiGRU(cell)
end

function (ligru::LiGRU)(inp)
state = zeros_like(inp, size(ligru.cell.Wh, 2))
return ligru(inp, state)
end

function (ligru::LiGRU)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
Expand Down
14 changes: 2 additions & 12 deletions src/mgu_cell.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#https://arxiv.org/pdf/1603.09420
struct MGUCell{I, H, V}
struct MGUCell{I, H, V} <: AbstractRecurrentCell
Wi::I
Wh::H
bias::V
Expand Down Expand Up @@ -48,11 +48,6 @@ function MGUCell((input_size, hidden_size)::Pair;
return MGUCell(Wi, Wh, b)
end

function (mgu::MGUCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(mgu.Wh, 2))
return mgu(inp, state)
end

function (mgu::MGUCell)(inp::AbstractVecOrMat, state)
_size_check(mgu, inp, 1 => size(mgu.Wi,2))
Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias
Expand All @@ -70,7 +65,7 @@ Base.show(io::IO, mgu::MGUCell) =
print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")")


struct MGU{M}
struct MGU{M} <: AbstractRecurrentLayer
cell::M
end

Expand Down Expand Up @@ -102,11 +97,6 @@ function MGU((input_size, hidden_size)::Pair; kwargs...)
cell = MGUCell(input_size => hidden_size; kwargs...)
return MGU(cell)
end

function (mgu::MGU)(inp)
state = zeros_like(inp, size(mgu.cell.Wh, 2))
return mgu(inp, state)
end

function (mgu::MGU)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
Expand Down
Loading

2 comments on commit 408a5c8

@MartinuzziFrancesco
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/121351

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.4 -m "<description of version>" 408a5c80a113616e6af6ec2e2c04755d18f84bfe
git push origin v0.1.4

Please sign in to comment.