diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 10f5aa7..dec1ab4 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -35,7 +35,7 @@ function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) return rlayer(inp, state) end -function (rlayer::AbstractRecurrentLayer)(inp, state) +function (rlayer::AbstractRecurrentLayer)(inp::AbstractArray, state::AbstractVecOrMat) @assert ndims(inp) == 2 || ndims(inp) == 3 return scan(rlayer.cell, inp, state) end @@ -47,7 +47,6 @@ export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3, SCRN, PeepholeLSTM, FastRNN, FastGRNN -#TODO add double bias include("mgu_cell.jl") include("ligru_cell.jl") include("indrnn_cell.jl")