Skip to content

Commit

Permalink
Merge pull request #14 from MartinuzziFrancesco/fm/rhn
Browse files Browse the repository at this point in the history
Fix RHN
  • Loading branch information
MartinuzziFrancesco authored Dec 3, 2024
2 parents 45d9333 + 7663a9e commit 31fceed
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ end
function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
_size_check(indrnn, inp, 1 => size(indrnn.Wi, 2))
σ = NNlib.fast_act(indrnn.σ, inp)
state = σ.(indrnn.Wi*inp .+ indrnn.u.*state .+ indrnn.b)
state = σ.(indrnn.Wi*inp .+ indrnn.u .* state .+ indrnn.b)
return state
end

Expand Down
2 changes: 1 addition & 1 deletion src/rhn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state)
weight, bias = rhn.weight, rhn.bias

#compute
pre_nonlin = weight * inp + bias
pre_nonlin = weight * inp .+ bias

#split
pre_h, pre_t, pre_c = chunk(pre_nonlin, 3, dims = 1)
Expand Down
8 changes: 4 additions & 4 deletions test/test_cells.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ end
@test rnncell(inp) == rnncell(inp, zeros(Float32, 5))

##TODO rhncell bias is bugged atm
#rnncell = RHNCell(3 => 5; bias=false)
#@test length(Flux.trainables(rnncell)) == 3
rnncell = RHNCell(3 => 5; bias=false)
@test length(Flux.trainables(rnncell)) == 3

#inp = rand(Float32, 3)
#@test rnncell(inp) == rnncell(inp, zeros(Float32, 5))
inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, zeros(Float32, 5))
end

0 comments on commit 31fceed

Please sign in to comment.