diff --git a/Project.toml b/Project.toml index a052fd1..4c677cc 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,8 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "SafeTestsets"] +test = ["Aqua", "Test", "SafeTestsets", "JET"] diff --git a/src/cells/fastrnn_cell.jl b/src/cells/fastrnn_cell.jl index 9f49341..3aa61f5 100644 --- a/src/cells/fastrnn_cell.jl +++ b/src/cells/fastrnn_cell.jl @@ -137,7 +137,7 @@ end function Base.show(io::IO, fastrnn::FastRNN) print(io, "FastRNN(", size(fastrnn.cell.Wi, 2), " => ", size(fastrnn.cell.Wi, 1)) - print(io, ", ", fastgrnn.cell.activation) + print(io, ", ", fastrnn.cell.activation) print(io, ")") end diff --git a/src/cells/rhn_cell.jl b/src/cells/rhn_cell.jl index 1ee74c9..b60a2f7 100644 --- a/src/cells/rhn_cell.jl +++ b/src/cells/rhn_cell.jl @@ -113,7 +113,7 @@ end function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat) - current_state = state + current_state = colify(state) for (i, layer) in enumerate(rhn.layers.layers) if i == 1 @@ -190,4 +190,11 @@ end function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat) @assert ndims(inp) == 2 || ndims(inp) == 3 return scan(rhn.cell, inp, state) +end + +function colify(x::AbstractArray) + # If x is already 2D (e.g. (N,1)), leave it. + # If x is 1D (N,), reshape to (N, 1). + ndims(x) == 1 && return reshape(x, (length(x), 1)) + return x end \ No newline at end of file diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl index b200849..8761449 100644 --- a/src/wrappers/stackedrnn.jl +++ b/src/wrappers/stackedrnn.jl @@ -51,9 +51,11 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...; end function (stackedrnn::StackedRNN)(inp::AbstractArray) - for (idx,(layer, state)) in enumerate(zip(stackedrnn.layers, stackedrnn.states)) - inp = layer(inp, state) - if !(idx == length(stackedrnn.layers)) + @assert length(stackedrnn.layers) == length(stackedrnn.states) "Mismatch in layers vs. states length!" + @assert !isempty(stackedrnn.layers) "StackedRNN has no layers!" + for idx in eachindex(stackedrnn.layers) + inp = stackedrnn.layers[idx](inp, stackedrnn.states[idx]) + if idx != length(stackedrnn.layers) inp = stackedrnn.dropout(inp) end end diff --git a/test/qa.jl b/test/qa.jl index 736cbce..49ca321 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -1,4 +1,6 @@ using RecurrentLayers using Aqua +using JET Aqua.test_all(RecurrentLayers; ambiguities=false, deps_compat=(check_extras = false)) +JET.test_package(RecurrentLayers) \ No newline at end of file