diff --git a/docs/pages.jl b/docs/pages.jl
index de181af..ce4d3cb 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -3,6 +3,7 @@ pages=[
     "API Documentation" => [
         "Cells" => "api/cells.md",
         "Layers" => "api/layers.md",
+        "Wrappers" => "api/wrappers.md",
     ],
     "Roadmap" => "roadmap.md"
 ]
\ No newline at end of file
diff --git a/docs/src/api/layers.md b/docs/src/api/layers.md
index a7f3bbe..f408b3d 100644
--- a/docs/src/api/layers.md
+++ b/docs/src/api/layers.md
@@ -1,4 +1,4 @@
-# Cell wrappers
+# Layers
 
 ```@docs
 RAN
diff --git a/docs/src/api/wrappers.md b/docs/src/api/wrappers.md
new file mode 100644
index 0000000..9c04cfc
--- /dev/null
+++ b/docs/src/api/wrappers.md
@@ -0,0 +1,5 @@
+# Wrappers
+
+```@docs
+StackedRNN
+```
\ No newline at end of file
diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl
index cec596b..ed97b53 100644
--- a/src/RecurrentLayers.jl
+++ b/src/RecurrentLayers.jl
@@ -45,21 +45,26 @@ end
 
 export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
 RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
 FastRNNCell, FastGRNNCell
+
 export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3, SCRN,
 PeepholeLSTM, FastRNN, FastGRNN
+export StackedRNN
+
 @compat(public, (initialstates))
 
-include("mgu_cell.jl")
-include("ligru_cell.jl")
-include("indrnn_cell.jl")
-include("ran_cell.jl")
-include("lightru_cell.jl")
-include("rhn_cell.jl")
-include("nas_cell.jl")
-include("mut_cell.jl")
-include("scrn_cell.jl")
-include("peepholelstm_cell.jl")
-include("fastrnn_cell.jl")
+include("cells/mgu_cell.jl")
+include("cells/ligru_cell.jl")
+include("cells/indrnn_cell.jl")
+include("cells/ran_cell.jl")
+include("cells/lightru_cell.jl")
+include("cells/rhn_cell.jl")
+include("cells/nas_cell.jl")
+include("cells/mut_cell.jl")
+include("cells/scrn_cell.jl")
+include("cells/peepholelstm_cell.jl")
+include("cells/fastrnn_cell.jl")
+
+include("wrappers/stackedrnn.jl")
 
 end #module
\ No newline at end of file
diff --git a/src/fastrnn_cell.jl b/src/cells/fastrnn_cell.jl
similarity index 100%
rename from src/fastrnn_cell.jl
rename to src/cells/fastrnn_cell.jl
diff --git a/src/indrnn_cell.jl b/src/cells/indrnn_cell.jl
similarity index 100%
rename from src/indrnn_cell.jl
rename to src/cells/indrnn_cell.jl
diff --git a/src/lightru_cell.jl b/src/cells/lightru_cell.jl
similarity index 100%
rename from src/lightru_cell.jl
rename to src/cells/lightru_cell.jl
diff --git a/src/ligru_cell.jl b/src/cells/ligru_cell.jl
similarity index 100%
rename from src/ligru_cell.jl
rename to src/cells/ligru_cell.jl
diff --git a/src/mgu_cell.jl b/src/cells/mgu_cell.jl
similarity index 100%
rename from src/mgu_cell.jl
rename to src/cells/mgu_cell.jl
diff --git a/src/mut_cell.jl b/src/cells/mut_cell.jl
similarity index 100%
rename from src/mut_cell.jl
rename to src/cells/mut_cell.jl
diff --git a/src/nas_cell.jl b/src/cells/nas_cell.jl
similarity index 100%
rename from src/nas_cell.jl
rename to src/cells/nas_cell.jl
diff --git a/src/peepholelstm_cell.jl b/src/cells/peepholelstm_cell.jl
similarity index 100%
rename from src/peepholelstm_cell.jl
rename to src/cells/peepholelstm_cell.jl
diff --git a/src/ran_cell.jl b/src/cells/ran_cell.jl
similarity index 100%
rename from src/ran_cell.jl
rename to src/cells/ran_cell.jl
diff --git a/src/rhn_cell.jl b/src/cells/rhn_cell.jl
similarity index 100%
rename from src/rhn_cell.jl
rename to src/cells/rhn_cell.jl
diff --git a/src/scrn_cell.jl b/src/cells/scrn_cell.jl
similarity index 100%
rename from src/scrn_cell.jl
rename to src/cells/scrn_cell.jl
diff --git a/src/sru_cell.jl b/src/cells/sru_cell.jl
similarity index 100%
rename from src/sru_cell.jl
rename to src/cells/sru_cell.jl
diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl
new file mode 100644
index 0000000..b200849
--- /dev/null
+++ b/src/wrappers/stackedrnn.jl
@@ -0,0 +1,62 @@
+# based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/
+struct StackedRNN{L,D,S}
+    layers::L
+    dropout::D
+    states::S
+end
+
+Flux.@layer StackedRNN trainable=(layers)
+
+@doc raw"""
+    StackedRNN(rlayer, input_size => hidden_size, args...;
+        num_layers = 1, kwargs...)
+
+Constructs a stack of recurrent layers given the recurrent layer type.
+
+Arguments:
+  - `rlayer`: Any recurrent layer such as [MGU](@ref), [RHN](@ref), etc., or
+    [`Flux.RNN`](@extref), [`Flux.LSTM`](@extref), etc.
+  - `input_size`: Defines the input dimension for the first layer.
+  - `hidden_size`: Defines the dimension of the hidden layers.
+  - `num_layers`: The number of layers to stack. Default is 1.
+  - `dropout`: Dropout probability applied between layers. Default is 0.0.
+  - `args...`: Additional positional arguments passed to the recurrent layers.
+  - `kwargs...`: Additional keyword arguments passed to the recurrent layers.
+
+Returns:
+  A `StackedRNN` instance containing the specified number of RNN layers and their initial states.
+"""
+function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
+        num_layers::Int = 1,
+        dropout::Number = 0.0,
+        dims = :,
+        active::Union{Bool,Nothing} = nothing,
+        rng = Flux.default_rng(),
+        kwargs...)
+    # build the container of stacked layers
+    layers = []
+    # warn if dropout is requested but there is only one layer
+    if num_layers == 1 && dropout != 0.0
+        @warn("Dropout is not applied when num_layers is 1.")
+    end
+
+    for idx in 1:num_layers
+        in_size = idx == 1 ? input_size : hidden_size
+        push!(layers, rlayer(in_size => hidden_size, args...; kwargs...))
+    end
+    states = [initialstates(layer) for layer in layers]
+
+    return StackedRNN(layers,
+        Dropout(dropout; dims = dims, active = active, rng = rng),
+        states)
+end
+
+function (stackedrnn::StackedRNN)(inp::AbstractArray)
+    for (idx, (layer, state)) in enumerate(zip(stackedrnn.layers, stackedrnn.states))
+        inp = layer(inp, state)
+        if idx != length(stackedrnn.layers)
+            inp = stackedrnn.dropout(inp)
+        end
+    end
+    return inp
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index c63188a..b86257c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,4 +11,8 @@ end
 
 @safetestset "Layers" begin
     include("test_layers.jl")
+end
+
+@safetestset "Wrappers" begin
+    include("test_wrappers.jl")
 end
\ No newline at end of file
diff --git a/test/test_wrappers.jl b/test/test_wrappers.jl
new file mode 100644
index 0000000..93aa354
--- /dev/null
+++ b/test/test_wrappers.jl
@@ -0,0 +1,20 @@
+using RecurrentLayers
+using Flux
+using Test
+
+layers = [RNN, GRU, GRUv3, LSTM, MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
+    SCRN, PeepholeLSTM, FastRNN, FastGRNN]
+
+@testset "Sizes for StackedRNN with layer: $layer" for layer in layers
+    wrap = StackedRNN(layer, 2 => 4)
+
+    inp = rand(Float32, 2, 3, 1)
+    output = wrap(inp)
+    @test output isa Array{Float32, 3}
+    @test size(output) == (4, 3, 1)
+
+    inp = rand(Float32, 2, 3)
+    output = wrap(inp)
+    @test output isa Array{Float32, 2}
+    @test size(output) == (4, 3)
+end
\ No newline at end of file
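
For context, a minimal usage sketch of the new `StackedRNN` wrapper, based on the docstring and tests in this diff. The choice of `MGU`, the sizes, and the hyperparameters are illustrative only:

```julia
using RecurrentLayers, Flux

# Stack three MGU layers: the first maps 2 -> 4 features, the
# remaining layers map 4 -> 4. Dropout is applied between layers
# only, never after the last layer.
stack = StackedRNN(MGU, 2 => 4; num_layers = 3, dropout = 0.2)

# Input follows Flux's recurrence convention: (features, steps, batch).
inp = rand(Float32, 2, 10, 8)
out = stack(inp)
size(out)  # (4, 10, 8)
```

As the tests show, the wrapper also accepts 2D input of shape (features, steps) for an unbatched sequence, and any exported cell wrapper or Flux recurrent layer can be passed in place of `MGU`.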