From 9b3dd1573a23a71cd71d6061c628c229356dbf91 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 1 Nov 2024 14:17:48 +0100 Subject: [PATCH] testing multilayer mgu --- Project.toml | 2 ++ src/RecurrentLayers.jl | 2 ++ src/mgu_cell.jl | 46 ++++++++++++++++++++++++++---------------- src/utils.jl | 9 +++++++++ 4 files changed, 42 insertions(+), 17 deletions(-) create mode 100644 src/utils.jl diff --git a/Project.toml b/Project.toml index 5491317..2b7d16a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,11 @@ version = "0.1.0" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Flux = "0.14" +Zygote = "0.6.72" julia = "1.10" [extras] diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 0102af5..891434b 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -17,4 +17,6 @@ include("lightru_cell.jl") include("rhn_cell.jl") include("nas_cell.jl") +include("utils.jl") + end #module diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 32f7e64..e165abb 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -46,30 +46,42 @@ Base.show(io::IO, mgu::MGUCell) = struct MGU{M} - cell::M + cells::Vector{M} + dropout::Real end - + Flux.@layer :expand MGU """ - MGU((in, out)::Pair; init = glorot_uniform, bias = true) + MGU((in_size, out_size)::Pair; n_layers = 1, dropout = 0.0, init = glorot_uniform, bias = true) """ -function MGU((in, out)::Pair; init = glorot_uniform, bias = true) - cell = MGUCell(in => out; init, bias) - return MGU(cell) +function MGU((in_size, out_size)::Pair; + n_layers::Int=1, + dropout::Float64=0.0, + kwargs...) + cells = [] + for i in 1:n_layers + tin_size = i == 1 ? in_size : out_size + push!(cells, MGUCell(tin_size => out_size; kwargs...)) + end + return MGU(cells, dropout) end -function (mgu::MGU)(inp) - state = zeros_like(inp, size(mgu.cell.Wh, 2)) - return mgu(inp, state) +# Forward pass without initial state +function (mgu::MGU)(input) + batch_size = size(input, 3) + state = [zeros(size(mgu.cells[i].Wh, 2), batch_size) for i in 1:length(mgu.cells)] + return mgu(input, state) end - -function (mgu::MGU)(inp, state) + +function (mgu::MGU)(inp, initial_states) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = mgu.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + num_layers = length(mgu.cells) + foldl((acc, idx) -> begin + (layer_input, states) = acc + cell = mgu.cells[idx] + layer_output, new_state = _process_layer(layer_input, states[idx], cell) + updated_states = vcat(states[1:idx-1], [new_state], states[idx+1:end]) + return layer_output, updated_states + end, 1:num_layers, init=(inp, initial_states))[1] end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..57ffcb7 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,9 @@ +function _process_layer(layer_input, state, cell) + new_states = map(eachslice(layer_input, dims=2)) do inp_t + state = cell(inp_t, state) + state + end + + layer_output = stack(new_states, dims=2) + return layer_output, state +end \ No newline at end of file