From 900869e60d119b18fb4b496238d75432776c56e3 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Wed, 18 Dec 2024 10:33:32 +0100 Subject: [PATCH 1/2] rhn fixes --- src/rhn_cell.jl | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl index 6a5866b..62b52c1 100644 --- a/src/rhn_cell.jl +++ b/src/rhn_cell.jl @@ -2,7 +2,7 @@ #https://github.com/jzilly/RecurrentHighwayNetworks/blob/master/rhn.py#L138C1-L180C60 struct RHNCellUnit{I,V} - weight::I + weights::I bias::V end @@ -15,18 +15,22 @@ Flux.@layer RHNCellUnit """ function RHNCellUnit((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, - bias = true) + bias::Bool = true) weight = init_kernel(3 * hidden_size, input_size) b = create_bias(weight, bias, size(weight, 1)) return RHNCellUnit(weight, b) end +function initialstates(rhn::RHNCellUnit) + return zeros_like(rhn.weights, size(rhn.weights, 1) ÷ 3) +end + function (rhn::RHNCellUnit)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(rhn.weight, 2)) + state = initialstates(rhn) return rhn(inp, state) end -function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state) +function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(rhn, inp, 1 => size(rhn.weight, 2)) weight, bias = rhn.weight, rhn.bias @@ -81,7 +85,7 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell rnncell(inp, [state]) """ -function RHNCell((input_size, hidden_size), depth::Int = 3; +function RHNCell((input_size, hidden_size), depth::Integer = 3; couple_carry::Bool = true, #sec 5, setup cell_kwargs...) @@ -98,13 +102,16 @@ function RHNCell((input_size, hidden_size), depth::Int = 3; return RHNCell(Chain(layers), couple_carry) end +function initialstates(rhn::RHNCell) + return initialstates(first(rhn.layers)) +end -function (rhn::RHNCell)(inp, state=nothing) +function (rhn::RHNCell)(inp::AbstractArray) + state = initialstates(rhn) + return rhn(inp, state) +end - #not ideal - if state == nothing - state = zeros_like(inp, size(rhn.layers.layers[2].weight, 2)) - end +function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat) current_state = state @@ -141,7 +148,7 @@ end Flux.@layer :noexpand RHN @doc raw""" - RHN((input_size => hidden_size)::Pair depth=3; kwargs...) + RHN((input_size => hidden_size) depth=3; kwargs...) [Recurrent highway network](https://arxiv.org/pdf/1607.03474). See [`RHNCellUnit`](@ref) for a the unit component of this layer. @@ -166,17 +173,21 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell \end{aligned} ``` """ -function RHN((input_size, hidden_size)::Pair, depth=3; kwargs...) +function RHN((input_size, hidden_size)::Pair, depth::Integer=3; kwargs...) cell = RHNCell(input_size => hidden_size, depth; kwargs...) return RHN(cell) end + +function initialstates(rhn::RHN) + return initialstates(rhn.cell) +end -function (rhn::RHN)(inp) - state = zeros_like(inp, size(rhn.cell.layers[2].weights, 2)) +function (rhn::RHN)(inp::AbstractArray) + state = initialstates(rhn) return rhn(inp, state) end -function (rhn::RHN)(inp, state) +function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] for inp_t in eachslice(inp, dims=2) From 3b4d8d4ad527f88b4e13d800d08eeef2c9d95647 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 19 Dec 2024 11:15:34 +0100 Subject: [PATCH 2/2] final fixes --- Project.toml | 2 +- src/rhn_cell.jl | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 620cc1a..812dab4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecurrentLayers" uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c" authors = ["Francesco Martinuzzi"] -version = "0.2.2" +version = "0.2.3" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl index 62b52c1..68d0eba 100644 --- a/src/rhn_cell.jl +++ b/src/rhn_cell.jl @@ -31,8 +31,8 @@ function (rhn::RHNCellUnit)(inp::AbstractVecOrMat) end function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state::AbstractVecOrMat) - _size_check(rhn, inp, 1 => size(rhn.weight, 2)) - weight, bias = rhn.weight, rhn.bias + _size_check(rhn, inp, 1 => size(rhn.weights, 2)) + weight, bias = rhn.weights, rhn.bias #compute pre_nonlin = weight * inp .+ bias @@ -43,7 +43,7 @@ function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state::AbstractVecOrMat) end Base.show(io::IO, rhn::RHNCellUnit) = - print(io, "RHNCellUnit(", size(rhn.weight, 2), " => ", size(rhn.weight, 1)÷3, ")") + print(io, "RHNCellUnit(", size(rhn.weights, 2), " => ", size(rhn.weights, 1)÷3, ")") struct RHNCell{C} layers::C @@ -189,10 +189,5 @@ end function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = rhn.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(rhn.cell, inp, state) end \ No newline at end of file