diff --git a/src/peepholelstm_cell.jl b/src/peepholelstm_cell.jl index 0898057..2bb5e9e 100644 --- a/src/peepholelstm_cell.jl +++ b/src/peepholelstm_cell.jl @@ -63,7 +63,7 @@ function PeepholeLSTMCell( end function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat) - state = zeros_like(lstm, size(lstm.Wh, 2)) + state = zeros_like(inp, size(lstm.Wh, 2)) c_state = zeros_like(state) return lstm(inp, (state, c_state)) end @@ -75,7 +75,7 @@ function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, g = lstm.Wi * inp .+ lstm.Wh * c_state .+ b input, forget, cell, output = chunk(g, 4; dims = 1) new_cstate = @. sigmoid_fast(forget) * c_state + sigmoid_fast(input) * tanh_fast(cell) - new_state = @. sigmoid_fast(output) * tanh_fast(c′) + new_state = @. sigmoid_fast(output) * tanh_fast(new_cstate) return new_state, new_cstate end