Skip to content

Commit

Permalink
some more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 2, 2024
1 parent 28d6646 commit 45d9333
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion test/test_cells.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ double_cells = [RANCell, NASCell, PeepholeLSTMCell]
#cells with a little more complexity to them
different_cells = [SCRNCell, RHNCell]

@testset "Signle return cell: cell = $cell" for cell in single_cells
@testset "Single return cell: cell = $cell" for cell in single_cells
rnncell = cell(3 => 5)
@test length(Flux.trainables(rnncell)) == 3

Expand All @@ -37,4 +37,33 @@ end

inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
end

@testset "SCRNCell" begin
rnncell = SCRNCell(3 => 5)
@test length(Flux.trainables(rnncell)) == 4

inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5)))

rnncell = SCRNCell(3 => 5; bias=false)
@test length(Flux.trainables(rnncell)) == 3

inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
end

@testset "RHNCell" begin
rnncell = RHNCell(3 => 5)
@test length(Flux.trainables(rnncell)) == 6

inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, zeros(Float32, 5))

##TODO rhncell bias is bugged atm
#rnncell = RHNCell(3 => 5; bias=false)
#@test length(Flux.trainables(rnncell)) == 3

#inp = rand(Float32, 3)
#@test rnncell(inp) == rnncell(inp, zeros(Float32, 5))
end

0 comments on commit 45d9333

Please sign in to comment.