Skip to content

Commit

Permalink
rhn fixes, general improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Jan 9, 2025
1 parent 157a4e9 commit 07b536a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 47 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ version = "0.2.4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[compat]
Compat = "4.16.0"
Flux = "0.16"
Functors = "0.5.2"
NNlib = "0.9.27"
julia = "1.10"

[extras]
Expand Down
58 changes: 30 additions & 28 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,17 @@ module RecurrentLayers

using Compat: @compat
using Flux: _size_check, _match_eltype, chunk, create_bias,
zeros_like, glorot_uniform, scan, @layer, default_rng, tanh_fast
zeros_like, glorot_uniform, scan, @layer,
default_rng, Chain, Dropout
import Flux: initialstates
import Functors: functor

rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1,
:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN)

rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell,
:MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell,
:RANCell, :SCRNCell)

for (rlayer,rcell) in zip(rlayers, rcells)
@eval begin
function ($rlayer)(rc::$rcell; return_state::Bool = false)
return $rlayer{return_state, typeof(rc)}(rc)
end

# why won't this work? (likely cause: `(cell = rl.cell)` needs a trailing comma to be a NamedTuple)
#function functor(rl::$rlayer{S}) where {S}
# params = (cell = rl.cell)
# reconstruct = p -> $rlayer{S, typeof(p.cell)}(p.cell)
# return params, reconstruct
#end
end
end
using NNlib: fast_act, sigmoid_fast, tanh_fast, relu

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell

RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN

SCRN, PeepholeLSTM, FastRNN, FastGRNN
export StackedRNN

@compat(public, (initialstates))
Expand All @@ -55,4 +33,28 @@ include("cells/fastrnn_cell.jl")

include("wrappers/stackedrnn.jl")


### fallbacks for functors ###
# Names of the recurrent layer wrapper types that receive a generated
# cell-based constructor in the loop below. `RHN` is not in this list.
rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1,
:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN)

# Names of the corresponding cell types, in the same order as `rlayers`;
# the two tuples are paired positionally by `zip`.
rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell,
:MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell,
:RANCell, :SCRNCell)

# For every (layer, cell) pair, generate a constructor method that wraps an
# already-built cell `rc` in its layer type. The `return_state` keyword
# (default `false`) becomes the first type parameter of the layer and the
# concrete cell type becomes the second.
for (rlayer,rcell) in zip(rlayers, rcells)
@eval begin
function ($rlayer)(rc::$rcell; return_state::Bool = false)
return $rlayer{return_state, typeof(rc)}(rc)
end

# why won't this work?
# NOTE(review): a likely cause is that `(cell = rl.cell)` below has no
# trailing comma, so it is a parenthesized assignment rather than a one-field
# NamedTuple; `params` would then be the cell itself instead of a container
# with a `cell` field. Writing `(cell = rl.cell,)` builds the NamedTuple.
# TODO confirm before re-enabling.
#function functor(rl::$rlayer{S}) where {S}
# params = (cell = rl.cell)
# reconstruct = p -> $rlayer{S, typeof(p.cell)}(p.cell)
# return params, reconstruct
#end
end
end

end #module
2 changes: 1 addition & 1 deletion src/cells/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end

function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
_size_check(indrnn, inp, 1 => size(indrnn.Wi, 2))
σ = NNlib.fast_act(indrnn.σ, inp)
σ = fast_act(indrnn.σ, inp)
state = σ.(indrnn.Wi*inp .+ indrnn.Wh .* state .+ indrnn.b)
return state, state
end
Expand Down
29 changes: 11 additions & 18 deletions src/cells/rhn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat)
pre_h, pre_t, pre_c = layer(inp_combined)

# Apply nonlinearities
hidden_gate = tanh.(pre_h)
transform_gate = sigmoid.(pre_t)
carry_gate = sigmoid.(pre_c)
hidden_gate = tanh_fast.(pre_h)
transform_gate = sigmoid_fast.(pre_t)
carry_gate = sigmoid_fast.(pre_c)

# Highway component
if rhn.couple_carry
Expand Down Expand Up @@ -167,29 +167,22 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell
\end{aligned}
```
"""
struct RHN{M}
struct RHN{S,M} <: AbstractRecurrentLayer{S}
cell::M
end

@layer :noexpand RHN

function RHN((input_size, hidden_size)::Pair, depth::Integer=3; kwargs...)
function RHN((input_size, hidden_size)::Pair, depth::Integer=3;
return_state::Bool = false, kwargs...)
cell = RHNCell(input_size => hidden_size, depth; kwargs...)
return RHN(cell)
return RHN{return_state, typeof(cell)}(cell)
end

function initialstates(rhn::RHN)
return initialstates(rhn.cell)
end

function (rhn::RHN)(inp::AbstractArray)
state = initialstates(rhn)
return rhn(inp, state)
end

function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat)
@assert ndims(inp) == 2 || ndims(inp) == 3
return scan(rhn.cell, inp, state)
function functor(rhn::RHN{S}) where {S}
params = (cell = rhn.cell,)
reconstruct = p -> RHN{S, typeof(p.cell)}(p.cell)
return params, reconstruct
end

function colify(x::AbstractArray)
Expand Down

0 comments on commit 07b536a

Please sign in to comment.