Skip to content

Commit

Permalink
adding jet tests and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 24, 2024
1 parent bc0af85 commit b024c96
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 6 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets"]
test = ["Aqua", "Test", "SafeTestsets", "JET"]
2 changes: 1 addition & 1 deletion src/cells/fastrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ end

function Base.show(io::IO, fastrnn::FastRNN)
print(io, "FastRNN(", size(fastrnn.cell.Wi, 2), " => ", size(fastrnn.cell.Wi, 1))
print(io, ", ", fastgrnn.cell.activation)
print(io, ", ", fastrnn.cell.activation)
print(io, ")")
end

Expand Down
9 changes: 8 additions & 1 deletion src/cells/rhn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ end

function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat)

current_state = state
current_state = colify(state)

for (i, layer) in enumerate(rhn.layers.layers)
if i == 1
Expand Down Expand Up @@ -190,4 +190,11 @@ end
function (rhn::RHN)(inp::AbstractArray, state::AbstractVecOrMat)
@assert ndims(inp) == 2 || ndims(inp) == 3
return scan(rhn.cell, inp, state)
end

function colify(x::AbstractArray)
# If x is already 2D (e.g. (N,1)), leave it.
# If x is 1D (N,), reshape to (N, 1).
ndims(x) == 1 && return reshape(x, (length(x), 1))
return x
end
8 changes: 5 additions & 3 deletions src/wrappers/stackedrnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
end

function (stackedrnn::StackedRNN)(inp::AbstractArray)
for (idx,(layer, state)) in enumerate(zip(stackedrnn.layers, stackedrnn.states))
inp = layer(inp, state)
if !(idx == length(stackedrnn.layers))
@assert length(stackedrnn.layers) == length(stackedrnn.states) "Mismatch in layers vs. states length!"
@assert !isempty(stackedrnn.layers) "StackedRNN has no layers!"
for idx in eachindex(stackedrnn.layers)
inp = stackedrnn.layers[idx](inp, stackedrnn.states[idx])
if idx != length(stackedrnn.layers)
inp = stackedrnn.dropout(inp)
end
end
Expand Down
2 changes: 2 additions & 0 deletions test/qa.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using RecurrentLayers
using Aqua
using JET

Aqua.test_all(RecurrentLayers; ambiguities=false, deps_compat=(check_extras = false))
JET.test_package(RecurrentLayers)

0 comments on commit b024c96

Please sign in to comment.