start of mut
MartinuzziFrancesco committed Nov 3, 2024
1 parent 9eb926a commit 69023bd
Showing 3 changed files with 50 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -19,6 +19,7 @@ Currently available layers and work in progress in the short term:
- [x] Recurrent highway network (RHN) [arxiv](https://arxiv.org/pdf/1607.03474)
- [x] Light recurrent unit (LightRU) [pub](https://www.mdpi.com/2079-9292/13/16/3204)
- [x] Neural architecture search unit (NAS) [arxiv](https://arxiv.org/abs/1611.01578)
- [x] Evolving recurrent neural networks (MUT1/2/3) [pub](https://proceedings.mlr.press/v37/jozefowicz15.pdf)
- [ ] Minimal gated recurrent unit (minGRU) and minimal long short-term memory (minLSTM) [arxiv](https://arxiv.org/abs/2410.01201)

## Installation
3 changes: 2 additions & 1 deletion src/RecurrentLayers.jl
@@ -5,7 +5,7 @@ import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
import Flux: glorot_uniform

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell
RHNCellUnit, NASCell, MUT1Cell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN

#TODO add double bias
@@ -16,5 +16,6 @@ include("ran_cell.jl")
include("lightru_cell.jl")
include("rhn_cell.jl")
include("nas_cell.jl")
include("mut_cell.jl")

end #module
47 changes: 47 additions & 0 deletions src/mut_cell.jl
@@ -0,0 +1,47 @@
#https://proceedings.mlr.press/v37/jozefowicz15.pdf
struct MUT1Cell{I, H, V}
    Wi::I
    Wh::H
    bias::V
end

Flux.@layer MUT1Cell

"""
MUT1Cell((in, out)::Pair; init = glorot_uniform, bias = true)
"""
function MUT1Cell((in, out)::Pair;
    init = glorot_uniform,
    bias = true)

    Wi = init(out * 2, in)               # input projections for the forget and reset gates
    Wh = init(out * 2, out)              # hidden projections for the reset gate and candidate
    b = create_bias(Wi, bias, 3 * out)   # biases for forget gate, reset gate, and candidate

    return MUT1Cell(Wi, Wh, b)
end

function (mut::MUT1Cell)(inp::AbstractVecOrMat)
    # default to a zero hidden state when none is provided
    state = zeros_like(inp, size(mut.Wh, 2))
    return mut(inp, state)
end

function (mut::MUT1Cell)(inp::AbstractVecOrMat, state)
    _size_check(mut, inp, 1 => size(mut.Wi, 2))
    Wi, Wh, b = mut.Wi, mut.Wh, mut.bias
    # split the projections into per-gate blocks
    gxs = chunk(Wi * inp, 2, dims=1)
    ghs = chunk(Wh, 2, dims=1)
    bs = chunk(b, 3, dims=1)

    # z = σ(Wxz x + bz)
    forget_gate = sigmoid_fast.(gxs[1] .+ bs[1])
    # r = σ(Wxr x + Whr h + br)
    reset_gate = sigmoid_fast.(gxs[2] .+ ghs[1] * state .+ bs[2])
    # hnew = tanh(Whh (r ⊙ h) + tanh(x) + bh) ⊙ z + h ⊙ (1 - z)
    candidate_state = tanh_fast.(
        ghs[2] * (reset_gate .* state) .+ tanh_fast.(inp) .+ bs[3]
    )
    new_state = candidate_state .* forget_gate .+ state .* (1 .- forget_gate)
    return new_state
end

Base.show(io::IO, mut::MUT1Cell) =
print(io, "MUT1Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 2, ")")
