Skip to content

Commit

Permalink
Merge pull request #45 from MartinuzziFrancesco/fm/mf
Browse files Browse the repository at this point in the history
Quality of life improvements
  • Loading branch information
MartinuzziFrancesco authored Jan 16, 2025
2 parents c1e466c + 0d33ad1 commit beead42
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 35 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ jobs:
fail-fast: false
matrix:
version:
- '1.10'
- '1.11'
- '1'
- 'lts'
- 'pre'
os:
- ubuntu-latest
- windows-latest
- macos-latest
arch:
- x64
steps:
Expand Down
42 changes: 42 additions & 0 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# CI workflow: run JuliaFormatter over the repository and fail the job
# if `git diff` shows any file was changed by formatting.
name: format-check

on:
  push:
    branches:
      - 'main'
      # Fixed: 'release-' (no glob) only matches a branch literally named
      # "release-"; 'release-*' matches every release branch.
      - 'release-*'
    tags: '*'
  pull_request:

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        julia-version: [1]
        # NOTE(review): julia-arch is declared but never passed to
        # setup-julia below — confirm whether it is still needed.
        julia-arch: [x86]
        os: [ubuntu-latest]
    steps:
      - uses: julia-actions/setup-julia@latest
        with:
          version: ${{ matrix.julia-version }}

      - uses: actions/checkout@v4
      - name: Install JuliaFormatter and format
        # This will use the latest version by default but you can set the version like so:
        #
        # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))'
        run: |
          julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
          julia -e 'using JuliaFormatter; format(".", verbose=true)'
      - name: Format check
        # Exit non-zero (failing the job) when formatting modified any file.
        run: |
          julia -e '
          out = Cmd(`git diff --name-only`) |> read |> String
          if out == ""
              exit(0)
          else
              @error "Some files have not been formatted !!!"
              write(stdout, out)
              exit(1)
          end'
7 changes: 3 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using RecurrentLayers
using Documenter, DocumenterInterLinks
using RecurrentLayers, Documenter, DocumenterInterLinks
include("pages.jl")

DocMeta.setdocmeta!(
RecurrentLayers, :DocTestSetup, :(using RecurrentLayers); recursive=true)
mathengine = Documenter.MathJax()

links = InterLinks(
Expand All @@ -14,6 +11,8 @@ makedocs(;
modules=[RecurrentLayers],
authors="Francesco Martinuzzi",
sitename="RecurrentLayers.jl",
clean=true, doctest=true,
linkcheck=true,
format=Documenter.HTML(;
mathengine,
assets=["assets/favicon.ico"],
Expand Down
5 changes: 2 additions & 3 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ module RecurrentLayers

using Compat: @compat
using Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like, glorot_uniform,
scan, @layer, default_rng, Chain, Dropout
scan, @layer, default_rng, Chain, Dropout, sigmoid_fast, tanh_fast, relu
import Flux: initialstates
import Functors: functor
#to remove
using NNlib: fast_act, sigmoid_fast, tanh_fast, relu
using NNlib: fast_act

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
Expand Down
41 changes: 30 additions & 11 deletions src/wrappers/stackedrnn.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,49 @@
# based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/
struct StackedRNN{L, D, S}
layers::L
dropout::D
states::S
end

@layer StackedRNN trainable=(layers)

@doc raw"""
StackedRNN(rlayer, (input_size, hidden_size), args...;
num_layers = 1, kwargs...)
num_layers = 1, dropout = 0.0, kwargs...)
Constructs a stack of recurrent layers given the recurrent layer type.
Arguments:
# Arguments:
- `rlayer`: Any recurrent layer such as [MGU](@ref), [RHN](@ref), etc... or
[`Flux.RNN`](@extref), [`Flux.LSTM`](@extref), etc.
- `input_size`: Defines the input dimension for the first layer.
- `hidden_size`: defines the dimension of the hidden layer.
- `num_layers`: The number of layers to stack. Default is 1.
- `dropout`: Value of dropout to apply between recurrent layers.
Default is 0.0.
- `args...`: Additional positional arguments passed to the recurrent layer.
# Keyword arguments
- `kwargs...`: Additional keyword arguments passed to the recurrent layers.
Returns:
A `StackedRNN` instance containing the specified number of RNN layers and their initial states.
# Examples
```jldoctest
julia> using RecurrentLayers
julia> stac_rnn = StackedRNN(MGU, (3=>5); num_layers = 4)
StackedRNN(
[
MGU(3 => 10), # 90 parameters
MGU(5 => 10), # 110 parameters
MGU(5 => 10), # 110 parameters
MGU(5 => 10), # 110 parameters
],
) # Total: 12 trainable arrays, 420 parameters,
# plus 4 non-trainable, 20 parameters, summarysize 2.711 KiB.
```
"""
struct StackedRNN{L, D, S}
layers::L
dropout::D
states::S
end

@layer StackedRNN trainable=(layers)

function StackedRNN(rlayer, (input_size, hidden_size)::Pair{<:Int, <:Int}, args...;
num_layers::Int=1, dropout::Number=0.0, dims=:,
active::Union{Bool, Nothing}=nothing, rng=default_rng(), kwargs...)
Expand Down
4 changes: 1 addition & 3 deletions test/qa.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using RecurrentLayers
using Aqua
using JET
using RecurrentLayers, Aqua, JET

Aqua.test_all(RecurrentLayers; ambiguities=false, deps_compat=(check_extras = false))
JET.test_package(RecurrentLayers)
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using SafeTestsets
using Test
using SafeTestsets, Test

@safetestset "Quality Assurance" begin
include("qa.jl")
Expand Down
4 changes: 1 addition & 3 deletions test/test_cells.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using RecurrentLayers
using Flux
using Test
using RecurrentLayers, Flux, Test

#cells returning a single hidden state
single_cells = [MGUCell, LiGRUCell, IndRNNCell,
Expand Down
5 changes: 1 addition & 4 deletions test/test_layers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
using RecurrentLayers
using Flux
using Test

using RecurrentLayers, Flux, Test
import Flux: initialstates

layers = [MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
Expand Down
4 changes: 1 addition & 3 deletions test/test_wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using RecurrentLayers
using Flux
using Test
using RecurrentLayers, Flux, Test

layers = [RNN, GRU, GRUv3, LSTM, MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN]
Expand Down

0 comments on commit beead42

Please sign in to comment.