diff --git a/test/test_cells.jl b/test/test_cells.jl index a6dd7cc..a4e713e 100644 --- a/test/test_cells.jl +++ b/test/test_cells.jl @@ -11,7 +11,7 @@ double_cells = [RANCell, NASCell, PeepholeLSTMCell] #cells with a little more complexity to them different_cells = [SCRNCell, RHNCell] -@testset "Signle return cell: cell = $cell" for cell in single_cells +@testset "Single return cell: cell = $cell" for cell in single_cells rnncell = cell(3 => 5) @test length(Flux.trainables(rnncell)) == 3 @@ -37,4 +37,33 @@ end inp = rand(Float32, 3) @test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5))) +end + +@testset "SCRNCell" begin + rnncell = SCRNCell(3 => 5) + @test length(Flux.trainables(rnncell)) == 4 + + inp = rand(Float32, 3) + @test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5))) + + rnncell = SCRNCell(3 => 5; bias=false) + @test length(Flux.trainables(rnncell)) == 3 + + inp = rand(Float32, 3) + @test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5))) +end + +@testset "RHNCell" begin + rnncell = RHNCell(3 => 5) + @test length(Flux.trainables(rnncell)) == 6 + + inp = rand(Float32, 3) + @test rnncell(inp) == rnncell(inp, zeros(Float32, 5)) + + ##TODO rhncell bias is bugged atm + #rnncell = RHNCell(3 => 5; bias=false) + #@test length(Flux.trainables(rnncell)) == 3 + + #inp = rand(Float32, 3) + #@test rnncell(inp) == rnncell(inp, zeros(Float32, 5)) end \ No newline at end of file