
Commit: rhn fixes
MartinuzziFrancesco committed Dec 18, 2024
1 parent 16908bf · commit 900869e
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions src/rhn_cell.jl
@@ -2,7 +2,7 @@
#https://github.com/jzilly/RecurrentHighwayNetworks/blob/master/rhn.py#L138C1-L180C60

struct RHNCellUnit{I,V}
weight::I
weights::I
bias::V
end

@@ -15,18 +15,22 @@ Flux.@layer RHNCellUnit
"""
function RHNCellUnit((input_size, hidden_size)::Pair;
init_kernel = glorot_uniform,
bias = true)
bias::Bool = true)
weight = init_kernel(3 * hidden_size, input_size)
b = create_bias(weight, bias, size(weight, 1))
return RHNCellUnit(weight, b)
end

function initialstates(rhn::RHNCellUnit)
return zeros_like(rhn.weights, size(rhn.weights, 1) ÷ 3)
end
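
For context, a minimal plain-Julia sketch (not part of the commit) of the shape bookkeeping behind `initialstates`: the unit's weight matrix stacks three blocks of `hidden_size` rows, so dividing its first dimension by 3 recovers the state length.

```julia
# Standalone shape check; mirrors init_kernel(3 * hidden_size, input_size) above.
input_size, hidden_size = 4, 8
weights = randn(Float32, 3 * hidden_size, input_size)
state = zeros(Float32, size(weights, 1) ÷ 3)   # what initialstates builds, as zeros
@assert length(state) == hidden_size
```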

function (rhn::RHNCellUnit)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(rhn.weight, 2))
state = initialstates(rhn)
return rhn(inp, state)
end

function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state)
function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
_size_check(rhn, inp, 1 => size(rhn.weights, 2))
weight, bias = rhn.weights, rhn.bias
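
Putting the constructor and the call methods together, the unit would be used roughly as below. This is a hypothetical sketch: it assumes the package defining `RHNCellUnit` and `initialstates` is loaded, and it only illustrates the shapes involved.

```julia
# Hypothetical usage of the unit defined above (names assume the package is loaded).
unit = RHNCellUnit(4 => 8)     # weights: 24×4, bias: length 24 (bias = true by default)
x = rand(Float32, 4)           # one input vector
st = initialstates(unit)       # zero state of length 8
y = unit(x, st)                # explicit-state call
y_default = unit(x)            # same, building the zero state internally
```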

@@ -81,7 +85,7 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell-1}^{[t]} + b_{c_{\ell}})
rnncell(inp, [state])
"""
function RHNCell((input_size, hidden_size), depth::Int = 3;
function RHNCell((input_size, hidden_size), depth::Integer = 3;
couple_carry::Bool = true, #sec 5, setup
cell_kwargs...)

@@ -98,13 +102,16 @@ function RHNCell((input_size, hidden_size), depth::Int = 3;
return RHNCell(Chain(layers), couple_carry)
end

function initialstates(rhn::RHNCell)
return initialstates(first(rhn.layers))
end

function (rhn::RHNCell)(inp, state=nothing)
function (rhn::RHNCell)(inp::AbstractArray)
state = initialstates(rhn)
return rhn(inp, state)
end

#not ideal
if state == nothing
state = zeros_like(inp, size(rhn.layers.layers[2].weight, 2))
end
function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat)

current_state = state
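
The pair of methods above replaces the old `state == nothing` runtime check with Julia's usual multiple-dispatch idiom: a one-argument method builds the default state via `initialstates` and forwards to the two-argument method. A self-contained toy sketch of that pattern (`ToyCell` and `default_state` are invented here for illustration, not the package's types):

```julia
# Toy illustration of the dispatch pattern adopted in this commit.
struct ToyCell
    hidden_size::Int
end

default_state(cell::ToyCell) = zeros(Float32, cell.hidden_size)

# One-argument method: build the default state, then forward.
(cell::ToyCell)(inp::AbstractVector) = cell(inp, default_state(cell))

# Two-argument method: the actual update (placeholder computation).
(cell::ToyCell)(inp::AbstractVector, state::AbstractVector) = tanh.(state .+ sum(inp))

cell = ToyCell(8)
cell(rand(Float32, 4))   # uses the default zero state
```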

@@ -141,7 +148,7 @@ end
Flux.@layer :noexpand RHN

@doc raw"""
RHN((input_size => hidden_size)::Pair depth=3; kwargs...)
RHN((input_size => hidden_size), depth=3; kwargs...)
[Recurrent highway network](https://arxiv.org/pdf/1607.03474).
See [`RHNCellUnit`](@ref) for the unit component of this layer.
@@ -166,17 +173,21 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell-1}^{[t]} + b_{c_{\ell}})
\end{aligned}
```
"""
function RHN((input_size, hidden_size)::Pair, depth=3; kwargs...)
function RHN((input_size, hidden_size)::Pair, depth::Integer=3; kwargs...)
cell = RHNCell(input_size => hidden_size, depth; kwargs...)
return RHN(cell)
end

function initialstates(rhn::RHN)
return initialstates(rhn.cell)
end

function (rhn::RHN)(inp)
state = zeros_like(inp, size(rhn.cell.layers[2].weights, 2))
function (rhn::RHN)(inp::AbstractArray)
state = initialstates(rhn)
return rhn(inp, state)
end
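
With `initialstates` wired through `RHN`, `RHNCell`, and `RHNCellUnit`, the wrapper layer can be driven end to end roughly as follows. This is a hypothetical sketch: it assumes the package exporting `RHN` is loaded, and that the second input dimension is the one iterated per time step, matching the `eachslice(inp, dims=2)` loop below.

```julia
# Hypothetical end-to-end usage of the RHN layer defined in this file.
rhn = RHN(4 => 8, 3)                          # input_size => hidden_size, depth = 3
seq = rand(Float32, 4, 10)                    # 4 features × 10 time steps
out = rhn(seq)                                # one-argument call: zero state from initialstates
out_explicit = rhn(seq, initialstates(rhn))   # equivalent explicit-state call
```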

function (rhn::RHN)(inp, state)
function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
for inp_t in eachslice(inp, dims=2)
