From 7663a9ec9d8724ac4fa1fa2c712411435aeec5ed Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 3 Dec 2024 19:48:10 +0100 Subject: [PATCH] fixing rhn --- src/indrnn_cell.jl | 2 +- src/rhn_cell.jl | 2 +- test/test_cells.jl | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index 1a15fd8..00400eb 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -54,7 +54,7 @@ end function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) σ = NNlib.fast_act(indrnn.σ, inp) - state = σ.(indrnn.Wi*inp .+ indrnn.u.*state .+ indrnn.b) + state = σ.(indrnn.Wi*inp .+ indrnn.u .* state .+ indrnn.b) return state end diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl index 554e217..db61ab5 100644 --- a/src/rhn_cell.jl +++ b/src/rhn_cell.jl @@ -31,7 +31,7 @@ function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state) weight, bias = rhn.weight, rhn.bias #compute - pre_nonlin = weight * inp + bias + pre_nonlin = weight * inp .+ bias #split pre_h, pre_t, pre_c = chunk(pre_nonlin, 3, dims = 1) diff --git a/test/test_cells.jl b/test/test_cells.jl index a4e713e..be3e48a 100644 --- a/test/test_cells.jl +++ b/test/test_cells.jl @@ -61,9 +61,9 @@ end @test rnncell(inp) == rnncell(inp, zeros(Float32, 5)) ##TODO rhncell bias is bugged atm - #rnncell = RHNCell(3 => 5; bias=false) - #@test length(Flux.trainables(rnncell)) == 3 + rnncell = RHNCell(3 => 5; bias=false) + @test length(Flux.trainables(rnncell)) == 3 - #inp = rand(Float32, 3) - #@test rnncell(inp) == rnncell(inp, zeros(Float32, 5)) + inp = rand(Float32, 3) + @test rnncell(inp) == rnncell(inp, zeros(Float32, 5)) end \ No newline at end of file