
Commit

some details
MartinuzziFrancesco committed Oct 28, 2024
1 parent caff8bf commit 35e92c1
Showing 4 changed files with 28 additions and 25 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -21,9 +21,11 @@ RecurrentLayers.jl extends [Flux.jl](https://github.com/FluxML/Flux.jl) recurrent
 Currently available layers and work in progress in the short term:
 - [x] Minimal gated unit (MGU) [arxiv](https://arxiv.org/abs/1603.09420)
 - [x] Light gated recurrent unit (LiGRU) [arxiv](https://arxiv.org/abs/1803.10225)
-- [ ] Minimal gated recurrent unit (minGRU) and minimal long short term memory (minLSTM) [arxiv](https://arxiv.org/abs/2410.01201)
 - [x] Independently recurrent neural networks (IndRNN) [arxiv](https://arxiv.org/abs/1803.04831)
 - [x] Recurrent additive networks (RAN) [arxiv](https://arxiv.org/abs/1705.07393)
 - [x] Recurrent highway network (RHN) [arxiv](https://arxiv.org/pdf/1607.03474)
+- [x] Light recurrent unit (LightRU) [pub](https://www.mdpi.com/2079-9292/13/16/3204)
+- [ ] Minimal gated recurrent unit (minGRU) and minimal long short term memory (minLSTM) [arxiv](https://arxiv.org/abs/2410.01201)
 
 ## Installation
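For context, here is a minimal usage sketch of the newly listed LightRU layer. This is not part of the commit; the constructor and call pattern follow the code in src/lru_cell.jl further down, while the concrete sizes and the (features, timesteps) shape convention are illustrative assumptions.

```julia
using RecurrentLayers

# Hypothetical example: LightRU(4 => 8) maps 4 input features to an
# 8-dimensional hidden state; the forward pass iterates over dims=2.
lightru = LightRU(4 => 8)
inp = rand(Float32, 4, 10)   # (features, timesteps)
states = lightru(inp)        # hidden states, one per timestep, per the loop below
```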
5 changes: 3 additions & 2 deletions src/RecurrentLayers.jl
@@ -4,13 +4,14 @@ using Flux
 import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
 import Flux: glorot_uniform
 
-export MGUCell, LiGRUCell, IndRNNCell, RANCell, LRUCell
-export MGU, LiGRU, IndRNN, RAN
+export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, RHNCellUnit
+export MGU, LiGRU, IndRNN, RAN, LightRU
 
 include("mgu_cell.jl")
 include("ligru_cell.jl")
 include("indrnn_cell.jl")
 include("ran_cell.jl")
 include("lru_cell.jl")
+include("rhn_cell.jl")
 
 end #module
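The renamed exports can also be exercised at the cell level. A short sketch, assuming only what the cell definitions below show (single-step cells that return the new state; the sizes here are illustrative):

```julia
using RecurrentLayers

# LightRUCell(in => out) builds Wi of size (2out, in) and Wh of size
# (out, out), matching the constructor in src/lru_cell.jl.
cell = LightRUCell(4 => 8)
x = rand(Float32, 4)   # one input vector
h = cell(x)            # state is initialized to zeros internally
h2 = cell(x, h)        # carry the returned state into the next step
```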
28 changes: 14 additions & 14 deletions src/lru_cell.jl
@@ -1,26 +1,26 @@
 #https://www.mdpi.com/2079-9292/13/16/3204
-struct LRUCell{I,H,V}
+struct LightRUCell{I,H,V}
     Wi::I
     Wh::H
     bias::V
 end
 
-function LRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
+function LightRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
     Wi = init(2 * out, in)
     Wh = init(out, out)
     b = create_bias(Wi, bias, size(Wh, 1))
 
-    return LRUCell(Wi, Wh, b)
+    return LightRUCell(Wi, Wh, b)
 end
 
-LRUCell(in, out; kwargs...) = LRUCell(in => out; kwargs...)
+LightRUCell(in, out; kwargs...) = LightRUCell(in => out; kwargs...)
 
-function (lru::LRUCell)(inp::AbstractVecOrMat)
+function (lru::LightRUCell)(inp::AbstractVecOrMat)
     state = zeros_like(inp, size(lru.Wh, 2))
     return lru(inp, state)
 end
 
-function (lru::LRUCell)(inp::AbstractVecOrMat, state)
+function (lru::LightRUCell)(inp::AbstractVecOrMat, state)
     _size_check(lru, inp, 1 => size(lru.Wi,2))
     Wi, Wh, b = lru.Wi, lru.Wh, lru.bias
 
@@ -34,28 +34,28 @@ function (lru::LRUCell)(inp::AbstractVecOrMat, state)
     return new_state
 end
 
-Base.show(io::IO, lru::LRUCell) =
-    print(io, "LRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1)÷2, ")")
+Base.show(io::IO, lru::LightRUCell) =
+    print(io, "LightRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1)÷2, ")")
 
 
 
-struct LRU{M}
+struct LightRU{M}
     cell::M
 end
 
 Flux.@layer :expand LRU
 
-function LRU((in, out)::Pair; init = glorot_uniform, bias = true)
-    cell = LRUCell(in => out; init, bias)
-    return LRU(cell)
+function LightRU((in, out)::Pair; init = glorot_uniform, bias = true)
+    cell = LightRUCell(in => out; init, bias)
+    return LightRU(cell)
 end
 
-function (lru::LRU)(inp)
+function (lru::LightRU)(inp)
     state = zeros_like(inp, size(lru.cell.Wh, 2))
     return lru(inp, state)
 end
 
-function (lru::LRU)(inp, state)
+function (lru::LightRU)(inp, state)
     @assert ndims(inp) == 2 || ndims(inp) == 3
     new_state = []
     for inp_t in eachslice(inp, dims=2)
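The LightRU forward pass above iterates its cell over the second (time) dimension of the input. A rough standalone equivalent of that loop, written as a hypothetical helper for illustration only:

```julia
# Sketch mirroring the `eachslice(inp, dims=2)` loop in the diff:
# apply the cell once per timestep, collecting each hidden state.
function manual_lightru(cell, inp::AbstractMatrix)
    state = zeros(Float32, size(cell.Wh, 2))
    states = []
    for inp_t in eachslice(inp, dims=2)
        state = cell(inp_t, state)
        push!(states, state)
    end
    return states
end
```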
16 changes: 8 additions & 8 deletions src/rhn_cell.jl
@@ -1,23 +1,23 @@
 #https://arxiv.org/pdf/1607.03474
 #https://github.com/jzilly/RecurrentHighwayNetworks/blob/master/rhn.py#L138C1-L180C60
 
-struct RHNCellLayer{I,V}
+struct RHNCellUnit{I,V}
     weight::I
     bias::V
 end
 
-function RHNCellLayer((in, out)::Pair; init = glorot_uniform, bias = true)
+function RHNCellUnit((in, out)::Pair; init = glorot_uniform, bias = true)
     weight = init(3 * out, in)
     b = create_bias(weight, bias, size(weight, 1))
-    return RHNCellLayer(weight, b)
+    return RHNCellUnit(weight, b)
 end
 
-function (rhn::RHNCellLayer)(inp::AbstractVecOrMat)
+function (rhn::RHNCellUnit)(inp::AbstractVecOrMat)
     state = zeros_like(inp, size(rhn.weight, 2))
     return rhn(inp, state)
 end
 
-function (rhn::RHNCellLayer)(inp::AbstractVecOrMat, state)
+function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state)
     _size_check(rhn, inp, 1 => size(rhn.weight, 2))
     weight, bias = rhn.weight, rhn.bias
 
@@ -29,8 +29,8 @@ function (rhn::RHNCellLayer)(inp::AbstractVecOrMat, state)
     return pre_h, pre_t, pre_c
 end
 
-Base.show(io::IO, rhn::RHNCellLayer) =
-    print(io, "RHNCellLayer(", size(rhn.weight, 2), " => ", size(rhn.weight, 1)÷3, ")")
+Base.show(io::IO, rhn::RHNCellUnit) =
+    print(io, "RHNCellUnit(", size(rhn.weight, 2), " => ", size(rhn.weight, 1)÷3, ")")
 
 struct RHNCell{C}
     layers::C
@@ -48,7 +48,7 @@ function RHNCell((in, out), depth=3;
         else
             real_in = out
         end
-        rhn = RHNCellLayer(real_in=>out; cell_kwargs...)
+        rhn = RHNCellUnit(real_in=>out; cell_kwargs...)
         push!(layers, rhn)
     end
     return RHNCell(Chain(layers), couple_carry)
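To illustrate the rename, a hedged sketch of calling the RHN pieces above. The triple return and the (3out, in) weight shape come from the diff; the concrete sizes, and everything the constructor does beyond what is shown, are assumptions:

```julia
using RecurrentLayers

# One RHNCellUnit emits the three preactivations returned by its
# forward pass above (pre_h, pre_t, pre_c).
unit = RHNCellUnit(4 => 8)
x = rand(Float32, 4)
pre_h, pre_t, pre_c = unit(x)

# RHNCell stacks `depth` units in a Chain, per the constructor shown above.
rhn = RHNCell(4 => 8, 3)
```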
