diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl index 6a5866b..62b52c1 100644 --- a/src/rhn_cell.jl +++ b/src/rhn_cell.jl @@ -2,7 +2,7 @@ #https://github.com/jzilly/RecurrentHighwayNetworks/blob/master/rhn.py#L138C1-L180C60 struct RHNCellUnit{I,V} - weight::I + weights::I bias::V end @@ -15,18 +15,22 @@ Flux.@layer RHNCellUnit """ function RHNCellUnit((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, - bias = true) + bias::Bool = true) weight = init_kernel(3 * hidden_size, input_size) b = create_bias(weight, bias, size(weight, 1)) return RHNCellUnit(weight, b) end +function initialstates(rhn::RHNCellUnit) + return zeros_like(rhn.weights, size(rhn.weights, 1) ÷ 3) +end + function (rhn::RHNCellUnit)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(rhn.weight, 2)) + state = initialstates(rhn) return rhn(inp, state) end -function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state) +function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(rhn, inp, 1 => size(rhn.weight, 2)) weight, bias = rhn.weight, rhn.bias @@ -81,7 +85,7 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell rnncell(inp, [state]) """ -function RHNCell((input_size, hidden_size), depth::Int = 3; +function RHNCell((input_size, hidden_size), depth::Integer = 3; couple_carry::Bool = true, #sec 5, setup cell_kwargs...) 
@@ -98,13 +102,16 @@ function RHNCell((input_size, hidden_size), depth::Int = 3; return RHNCell(Chain(layers), couple_carry) end +function initialstates(rhn::RHNCell) + return initialstates(first(rhn.layers)) +end -function (rhn::RHNCell)(inp, state=nothing) +function (rhn::RHNCell)(inp::AbstractArray) + state = initialstates(rhn) + return rhn(inp, state) +end - #not ideal - if state == nothing - state = zeros_like(inp, size(rhn.layers.layers[2].weight, 2)) - end +function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat) current_state = state @@ -141,7 +148,7 @@ end Flux.@layer :noexpand RHN @doc raw""" - RHN((input_size => hidden_size)::Pair depth=3; kwargs...) + RHN((input_size => hidden_size), depth=3; kwargs...) [Recurrent highway network](https://arxiv.org/pdf/1607.03474). See [`RHNCellUnit`](@ref) for a the unit component of this layer. @@ -166,17 +173,21 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell \end{aligned} ``` """ -function RHN((input_size, hidden_size)::Pair, depth=3; kwargs...) +function RHN((input_size, hidden_size)::Pair, depth::Integer=3; kwargs...) cell = RHNCell(input_size => hidden_size, depth; kwargs...) return RHN(cell) end + +function initialstates(rhn::RHN) + return initialstates(rhn.cell) +end -function (rhn::RHN)(inp) - state = zeros_like(inp, size(rhn.cell.layers[2].weights, 2)) +function (rhn::RHN)(inp::AbstractArray) + state = initialstates(rhn) return rhn(inp, state) end -function (rhn::RHN)(inp, state) +function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat) @assert ndims(inp) == 2 || ndims(inp) == 3 new_state = [] for inp_t in eachslice(inp, dims=2)