From 6795bc4c3771b1638f6c9f1b6b5822f595c4d3dc Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Fri, 20 Dec 2024 10:16:36 +0100
Subject: [PATCH 1/7] adding stackedrnn

---
 src/RecurrentLayers.jl               | 27 ++++++++++-------
 src/{ => cells}/fastrnn_cell.jl      |  0
 src/{ => cells}/indrnn_cell.jl       |  0
 src/{ => cells}/lightru_cell.jl      |  0
 src/{ => cells}/ligru_cell.jl        |  0
 src/{ => cells}/mgu_cell.jl          |  0
 src/{ => cells}/mut_cell.jl          |  0
 src/{ => cells}/nas_cell.jl          |  0
 src/{ => cells}/peepholelstm_cell.jl |  0
 src/{ => cells}/ran_cell.jl          |  0
 src/{ => cells}/rhn_cell.jl          |  0
 src/{ => cells}/scrn_cell.jl         |  0
 src/{ => cells}/sru_cell.jl          |  0
 src/wrappers/stackedrnn.jl           | 45 ++++++++++++++++++++++++++++
 14 files changed, 61 insertions(+), 11 deletions(-)
 rename src/{ => cells}/fastrnn_cell.jl (100%)
 rename src/{ => cells}/indrnn_cell.jl (100%)
 rename src/{ => cells}/lightru_cell.jl (100%)
 rename src/{ => cells}/ligru_cell.jl (100%)
 rename src/{ => cells}/mgu_cell.jl (100%)
 rename src/{ => cells}/mut_cell.jl (100%)
 rename src/{ => cells}/nas_cell.jl (100%)
 rename src/{ => cells}/peepholelstm_cell.jl (100%)
 rename src/{ => cells}/ran_cell.jl (100%)
 rename src/{ => cells}/rhn_cell.jl (100%)
 rename src/{ => cells}/scrn_cell.jl (100%)
 rename src/{ => cells}/sru_cell.jl (100%)
 create mode 100644 src/wrappers/stackedrnn.jl

diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl
index cec596b..ed97b53 100644
--- a/src/RecurrentLayers.jl
+++ b/src/RecurrentLayers.jl
@@ -45,21 +45,26 @@ end
 export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
 RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
 FastRNNCell, FastGRNNCell
+
 export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3, SCRN,
 PeepholeLSTM, FastRNN, FastGRNN
+export StackedRNN
+
 @compat(public, (initialstates))
 
-include("mgu_cell.jl")
-include("ligru_cell.jl")
-include("indrnn_cell.jl")
-include("ran_cell.jl")
-include("lightru_cell.jl")
-include("rhn_cell.jl")
-include("nas_cell.jl")
-include("mut_cell.jl")
-include("scrn_cell.jl")
-include("peepholelstm_cell.jl")
-include("fastrnn_cell.jl")
+include("cells/mgu_cell.jl")
+include("cells/ligru_cell.jl")
+include("cells/indrnn_cell.jl")
+include("cells/ran_cell.jl")
+include("cells/lightru_cell.jl")
+include("cells/rhn_cell.jl")
+include("cells/nas_cell.jl")
+include("cells/mut_cell.jl")
+include("cells/scrn_cell.jl")
+include("cells/peepholelstm_cell.jl")
+include("cells/fastrnn_cell.jl")
+
+include("wrappers/stackedrnn.jl")
 
 end #module
\ No newline at end of file
diff --git a/src/fastrnn_cell.jl b/src/cells/fastrnn_cell.jl
similarity index 100%
rename from src/fastrnn_cell.jl
rename to src/cells/fastrnn_cell.jl
diff --git a/src/indrnn_cell.jl b/src/cells/indrnn_cell.jl
similarity index 100%
rename from src/indrnn_cell.jl
rename to src/cells/indrnn_cell.jl
diff --git a/src/lightru_cell.jl b/src/cells/lightru_cell.jl
similarity index 100%
rename from src/lightru_cell.jl
rename to src/cells/lightru_cell.jl
diff --git a/src/ligru_cell.jl b/src/cells/ligru_cell.jl
similarity index 100%
rename from src/ligru_cell.jl
rename to src/cells/ligru_cell.jl
diff --git a/src/mgu_cell.jl b/src/cells/mgu_cell.jl
similarity index 100%
rename from src/mgu_cell.jl
rename to src/cells/mgu_cell.jl
diff --git a/src/mut_cell.jl b/src/cells/mut_cell.jl
similarity index 100%
rename from src/mut_cell.jl
rename to src/cells/mut_cell.jl
diff --git a/src/nas_cell.jl b/src/cells/nas_cell.jl
similarity index 100%
rename from src/nas_cell.jl
rename to src/cells/nas_cell.jl
diff --git a/src/peepholelstm_cell.jl b/src/cells/peepholelstm_cell.jl
similarity index 100%
rename from src/peepholelstm_cell.jl
rename to src/cells/peepholelstm_cell.jl
diff --git a/src/ran_cell.jl b/src/cells/ran_cell.jl
similarity index 100%
rename from src/ran_cell.jl
rename to src/cells/ran_cell.jl
diff --git a/src/rhn_cell.jl b/src/cells/rhn_cell.jl
similarity index 100%
rename from src/rhn_cell.jl
rename to src/cells/rhn_cell.jl
diff --git a/src/scrn_cell.jl b/src/cells/scrn_cell.jl
similarity index 100%
rename from src/scrn_cell.jl
rename to src/cells/scrn_cell.jl
diff --git a/src/sru_cell.jl b/src/cells/sru_cell.jl
similarity index 100%
rename from src/sru_cell.jl
rename to src/cells/sru_cell.jl
diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl
new file mode 100644
index 0000000..55bf6d2
--- /dev/null
+++ b/src/wrappers/stackedrnn.jl
@@ -0,0 +1,45 @@
+# based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/
+struct StackedRNN{L,S}
+    layers::L
+    states::S
+end
+
+Flux.@layer StackedRNN
+
+"""
+    StackedRNN(rlayer, (input_size, hidden_size), args...;
+        num_layers = 1, kwargs...)
+
+Constructs a stack of recurrent layers given the recurrent layer type.
+
+Arguments:
+  - `rlayer`: Any recurrent layer such as [MGU](@ref), [RHN](@ref), etc... or
+    [RNN](@extref), [LSTM](@extref), etc...
+  - `input_size`: Defines the input dimension for the first layer.
+  - `hidden_size`: defines the dimension of the hidden layer.
+  - `num_layers`: The number of layers to stack. Default is 1.
+  - `args...`: Additional positional arguments passed to the recurrent layer.
+  - `kwargs...`: Additional keyword arguments passed to the recurrent layers.
+
+Returns:
+  A `StackedRNN` instance containing the specified number of RNN layers and their initial states.
+"""
+function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
+    num_layers::Int = 1,
+    kwargs...)
+    layers = []
+    for (idx,layer) in enumerate(num_layers)
+        in_size = idx == 1 ? input_size : hidden_size
+        push!(layers, rlayer(in_size => hidden_size, args...; kwargs...))
+    end
+    states = [initialstates(layer) for layer in layers]
+
+    return StackedRNN(layers, states0)
+end
+
+function (stackedrnn::StackedRNN)(inp::AbstracArray)
+    for (layer, state) in zip(stackedrnn.layers, stackedrnn.states)
+        inp = layer(inp, state0)
+    end
+    return inp
+end

From 39339d504e60d80562af10714a8d80b0c2f133fa Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Fri, 20 Dec 2024 10:47:18 +0100
Subject: [PATCH 2/7] adding dropout to stackedrnn

---
 src/wrappers/stackedrnn.jl | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl
index 55bf6d2..4782754 100644
--- a/src/wrappers/stackedrnn.jl
+++ b/src/wrappers/stackedrnn.jl
@@ -1,6 +1,7 @@
 # based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/
-struct StackedRNN{L,S}
+struct StackedRNN{L,D,S}
     layers::L
+    droput::D
     states::S
 end
 
@@ -14,7 +15,8 @@ Constructs a stack of recurrent layers given the recurrent layer type.
 
 Arguments:
   - `rlayer`: Any recurrent layer such as [MGU](@ref), [RHN](@ref), etc... or
-    [RNN](@extref), [LSTM](@extref), etc...
+    [Flux.RNN](@extref), [Flux.LSTM](@extref), etc... Additionally anything wrapped in
+    [Flux.recurrence](@extref) can be used as `rlayer`.
   - `input_size`: Defines the input dimension for the first layer.
   - `hidden_size`: defines the dimension of the hidden layer.
   - `num_layers`: The number of layers to stack. Default is 1.
@@ -26,6 +28,7 @@ Returns:
 """
 function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
     num_layers::Int = 1,
+    dropout::Number = 0.0,
     kwargs...)
     layers = []
     for (idx,layer) in enumerate(num_layers)
@@ -34,12 +37,15 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
     end
     states = [initialstates(layer) for layer in layers]
 
-    return StackedRNN(layers, states0)
+    return StackedRNN(layers, Dropout(dropout), states)
 end
 
 function (stackedrnn::StackedRNN)(inp::AbstracArray)
-    for (layer, state) in zip(stackedrnn.layers, stackedrnn.states)
-        inp = layer(inp, state0)
-    end
-    return inp
+    for (idx,(layer, state)) in enumerate(zip(stackedrnn.layers, stackedrnn.states))
+        inp = layer(inp, state0)
+        if !(idx == length(stackedrnn.layers))
+            inp = stackedrnn.dropout(inp)
+        end
+    end
+    return inp
 end

From 13fbb039280924f35e8c4afffbb9dc149e87548a Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sat, 21 Dec 2024 19:19:28 +0100
Subject: [PATCH 3/7] small fixes to stackedrnn

---
 src/wrappers/stackedrnn.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl
index 4782754..9f0a05f 100644
--- a/src/wrappers/stackedrnn.jl
+++ b/src/wrappers/stackedrnn.jl
@@ -5,7 +5,7 @@ struct StackedRNN{L,D,S}
     states::S
 end
 
-Flux.@layer StackedRNN
+Flux.@layer StackedRNN trainable=(layers)
 
 """
     StackedRNN(rlayer, (input_size, hidden_size), args...;
@@ -40,7 +40,7 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
     return StackedRNN(layers, Dropout(dropout), states)
 end
 
-function (stackedrnn::StackedRNN)(inp::AbstracArray)
+function (stackedrnn::StackedRNN)(inp::AbstractArray)
     for (idx,(layer, state)) in enumerate(zip(stackedrnn.layers, stackedrnn.states))
         inp = layer(inp, state0)
         if !(idx == length(stackedrnn.layers))

From 92f8d02c6dad618346a3c0e182ea1556838e97cb Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sat, 21 Dec 2024 19:29:52 +0100
Subject: [PATCH 4/7] docs

---
 docs/pages.jl            | 1 +
 docs/src/api/layers.md   | 2 +-
 docs/src/api/wrappers.md | 5 +++++
 3 files changed, 7 insertions(+), 1 deletion(-)
 create mode 100644 docs/src/api/wrappers.md

diff --git a/docs/pages.jl b/docs/pages.jl
index de181af..ce4d3cb 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -3,6 +3,7 @@ pages=[
     "API Documentation" => [
         "Cells" => "api/cells.md",
         "Layers" => "api/layers.md",
+        "Wrappers" => "api/wrappers.md",
     ],
     "Roadmap" => "roadmap.md"
 ]
\ No newline at end of file
diff --git a/docs/src/api/layers.md b/docs/src/api/layers.md
index a7f3bbe..f408b3d 100644
--- a/docs/src/api/layers.md
+++ b/docs/src/api/layers.md
@@ -1,4 +1,4 @@
-# Cell wrappers
+# Layers
 
 ```@docs
 RAN
diff --git a/docs/src/api/wrappers.md b/docs/src/api/wrappers.md
new file mode 100644
index 0000000..9c04cfc
--- /dev/null
+++ b/docs/src/api/wrappers.md
@@ -0,0 +1,5 @@
+# Wrappers
+
+```@docs
+StackedRNN
+```
\ No newline at end of file

From 5c10776ea94a7dc116259176a586615caa4cd922 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sat, 21 Dec 2024 21:26:59 +0100
Subject: [PATCH 5/7] adding kwargs to dropout and warning

---
 src/wrappers/stackedrnn.jl | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl
index 9f0a05f..c74a636 100644
--- a/src/wrappers/stackedrnn.jl
+++ b/src/wrappers/stackedrnn.jl
@@ -7,7 +7,7 @@ end
 
 Flux.@layer StackedRNN trainable=(layers)
 
-"""
+@doc raw"""
     StackedRNN(rlayer, (input_size, hidden_size), args...;
         num_layers = 1, kwargs...)
 
@@ -15,8 +15,8 @@ Constructs a stack of recurrent layers given the recurrent layer type.
 
 Arguments:
   - `rlayer`: Any recurrent layer such as [MGU](@ref), [RHN](@ref), etc... or
-    [Flux.RNN](@extref), [Flux.LSTM](@extref), etc... Additionally anything wrapped in
-    [Flux.recurrence](@extref) can be used as `rlayer`.
+    [`Flux.RNN`](@extref), [`Flux.LSTM`](@extref), etc... Additionally anything wrapped in
+    [`Flux.recurrence`](@extref) can be used as `rlayer`.
   - `input_size`: Defines the input dimension for the first layer.
   - `hidden_size`: defines the dimension of the hidden layer.
   - `num_layers`: The number of layers to stack. Default is 1.
@@ -29,15 +29,26 @@ Returns:
 """
 function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
     num_layers::Int = 1,
     dropout::Number = 0.0,
+    dims = :,
+    active::Union{Bool,Nothing} = nothing,
+    rng = Flux.default_rng(),
     kwargs...)
+    #build container
     layers = []
+    #warn for dropout and num_layers
+    if num_layers ==1 && dropout != 0.0
+        @warn("Dropout is not applied when num_layers is 1.")
+    end
+
     for (idx,layer) in enumerate(num_layers)
         in_size = idx == 1 ? input_size : hidden_size
         push!(layers, rlayer(in_size => hidden_size, args...; kwargs...))
     end
     states = [initialstates(layer) for layer in layers]
 
-    return StackedRNN(layers, Dropout(dropout), states)
+    return StackedRNN(layers,
+        Dropout(dropout; dims = dims, active = active, rng = rng),
+        states)
 end
 
 function (stackedrnn::StackedRNN)(inp::AbstractArray)

From 0ee418bc6cbdfe68a6db5c1d8188531a7ec10dd8 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sun, 22 Dec 2024 14:55:25 +0100
Subject: [PATCH 6/7] trying to build docs

---
 src/wrappers/stackedrnn.jl | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl
index c74a636..58ecb77 100644
--- a/src/wrappers/stackedrnn.jl
+++ b/src/wrappers/stackedrnn.jl
@@ -15,8 +15,7 @@ Constructs a stack of recurrent layers given the recurrent layer type.
 
 Arguments:
   - `rlayer`: Any recurrent layer such as [MGU](@ref), [RHN](@ref), etc... or
-    [`Flux.RNN`](@extref), [`Flux.LSTM`](@extref), etc... Additionally anything wrapped in
-    [`Flux.recurrence`](@extref) can be used as `rlayer`.
+    [`Flux.RNN`](@extref), [`Flux.LSTM`](@extref), etc.
   - `input_size`: Defines the input dimension for the first layer.
   - `hidden_size`: defines the dimension of the hidden layer.
   - `num_layers`: The number of layers to stack. Default is 1.

From 36c585e715f82696bafb55025c26e475ef28deab Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sun, 22 Dec 2024 15:21:29 +0100
Subject: [PATCH 7/7] tests and fixes for stackedrnn

---
 src/wrappers/stackedrnn.jl |  6 +++---
 test/runtests.jl           |  4 ++++
 test/test_wrappers.jl      | 20 ++++++++++++++++++++
 3 files changed, 27 insertions(+), 3 deletions(-)
 create mode 100644 test/test_wrappers.jl

diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl
index 58ecb77..b200849 100644
--- a/src/wrappers/stackedrnn.jl
+++ b/src/wrappers/stackedrnn.jl
@@ -1,7 +1,7 @@
 # based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/
 struct StackedRNN{L,D,S}
     layers::L
-    droput::D
+    dropout::D
     states::S
 end
 
@@ -39,7 +39,7 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
         @warn("Dropout is not applied when num_layers is 1.")
     end
 
-    for (idx,layer) in enumerate(num_layers)
+    for idx in 1:num_layers
         in_size = idx == 1 ? input_size : hidden_size
         push!(layers, rlayer(in_size => hidden_size, args...; kwargs...))
     end
@@ -52,7 +52,7 @@ end
 
 function (stackedrnn::StackedRNN)(inp::AbstractArray)
     for (idx,(layer, state)) in enumerate(zip(stackedrnn.layers, stackedrnn.states))
-        inp = layer(inp, state0)
+        inp = layer(inp, state)
         if !(idx == length(stackedrnn.layers))
             inp = stackedrnn.dropout(inp)
         end
diff --git a/test/runtests.jl b/test/runtests.jl
index c63188a..b86257c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,4 +11,8 @@ end
 
 @safetestset "Layers" begin
     include("test_layers.jl")
+end
+
+@safetestset "Wrappers" begin
+    include("test_wrappers.jl")
 end
\ No newline at end of file
diff --git a/test/test_wrappers.jl b/test/test_wrappers.jl
new file mode 100644
index 0000000..93aa354
--- /dev/null
+++ b/test/test_wrappers.jl
@@ -0,0 +1,20 @@
+using RecurrentLayers
+using Flux
+using Test
+
+layers = [RNN, GRU, GRUv3, LSTM, MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
+SCRN, PeepholeLSTM, FastRNN, FastGRNN]
+
+@testset "Sizes for StackedRNN with layer: $layer" for layer in layers
+    wrap = StackedRNN(layer, 2 => 4)
+
+    inp = rand(Float32, 2, 3, 1)
+    output = wrap(inp)
+    @test output isa Array{Float32, 3}
+    @test size(output) == (4, 3, 1)
+
+    inp = rand(Float32, 2, 3)
+    output = wrap(inp)
+    @test output isa Array{Float32, 2}
+    @test size(output) == (4, 3)
+end
\ No newline at end of file
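
A minimal usage sketch of the `StackedRNN` API as it stands after patch 7, mirroring the new tests above; the choice of `MGU` and the `num_layers`/`dropout` values is illustrative only, not part of the patches:

using RecurrentLayers, Flux

# Stack three MGU layers: the first maps 2 -> 4 features, the rest 4 -> 4.
# Dropout (0.2 here) runs between layers but not after the last one, and
# Flux's Dropout is only active at train time under its default auto mode.
model = StackedRNN(MGU, 2 => 4; num_layers = 3, dropout = 0.2)

inp = rand(Float32, 2, 3)  # (input_size, sequence_length)
out = model(inp)           # (hidden_size, sequence_length) == (4, 3)
@assert size(out) == (4, 3)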