diff --git a/README.md b/README.md
index 5d3d734..0a8157c 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@ Currently available layers and work in progress in the short term:
  - [x] Recurrent highway network (RHN) [arixv](https://arxiv.org/pdf/1607.03474)
  - [x] Light recurrent unit (LightRU) [pub](https://www.mdpi.com/2079-9292/13/16/3204)
  - [x] Neural architecture search unit (NAS) [arxiv](https://arxiv.org/abs/1611.01578)
+ - [x] Evolving recurrent neural networks (MUT1/2/3) [pub](https://proceedings.mlr.press/v37/jozefowicz15.pdf)
  - [ ] Minimal gated recurrent unit (minGRU) and minimal long short term memory (minLSTM) [arxiv](https://arxiv.org/abs/2410.01201)
 
 ## Installation
diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl
index 0102af5..18c1089 100644
--- a/src/RecurrentLayers.jl
+++ b/src/RecurrentLayers.jl
@@ -5,7 +5,7 @@ import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
 import Flux: glorot_uniform
 
 export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
-RHNCellUnit, NASCell
+RHNCellUnit, NASCell, MUT1Cell
 export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN
 
 #TODO add double bias
@@ -16,5 +16,6 @@ include("ran_cell.jl")
 include("lightru_cell.jl")
 include("rhn_cell.jl")
 include("nas_cell.jl")
+include("mut_cell.jl")
 
 end #module
diff --git a/src/mut_cell.jl b/src/mut_cell.jl
new file mode 100644
index 0000000..97cfb3d
--- /dev/null
+++ b/src/mut_cell.jl
@@ -0,0 +1,75 @@
+#https://proceedings.mlr.press/v37/jozefowicz15.pdf
+struct MUT1Cell{I, H, V}
+    Wi::I
+    Wh::H
+    bias::V
+end
+
+Flux.@layer MUT1Cell
+
+"""
+    MUT1Cell((in, out)::Pair; init = glorot_uniform, bias = true)
+
+Mutated unit 1 (MUT1) recurrent cell, following
+[Jozefowicz et al. (2015)](https://proceedings.mlr.press/v37/jozefowicz15.pdf).
+
+# Arguments
+
+- `in => out`: input and hidden (output) dimensions.
+
+# Keywords
+
+- `init`: initializer for the weight matrices. Default is `glorot_uniform`.
+- `bias`: whether to use a trainable bias. Default is `true`.
+
+# Forward
+
+    mut1cell(inp, [state])
+
+Return the new hidden state; `state` defaults to zeros when omitted.
+
+The paper adds `tanh(x)` of the raw input, which requires `in == out`. Here the
+input is projected to `out` rows first so that any `in => out` pair works.
+"""
+function MUT1Cell((in, out)::Pair;
+    init = glorot_uniform,
+    bias = true)
+
+    # Three input projections (update gate, reset gate, candidate), two recurrent.
+    Wi = init(out * 3, in)
+    Wh = init(out * 2, out)
+    b = create_bias(Wi, bias, 3 * out)
+
+    return MUT1Cell(Wi, Wh, b)
+end
+
+# `create_bias` returns `false` when the bias is disabled, and `chunk` cannot
+# split a `Bool`, so hand the same `false` to every gate in that case.
+_mut_bias_chunks(b::AbstractArray, n::Int) = chunk(b, n, dims=1)
+_mut_bias_chunks(b::Bool, n::Int) = ntuple(_ -> b, n)
+
+function (mut::MUT1Cell)(inp::AbstractVecOrMat)
+    state = zeros_like(inp, size(mut.Wh, 2))
+    return mut(inp, state)
+end
+
+function (mut::MUT1Cell)(inp::AbstractVecOrMat, state)
+    _size_check(mut, inp, 1 => size(mut.Wi, 2))
+    Wi, Wh, b = mut.Wi, mut.Wh, mut.bias
+    #split
+    gxs = chunk(Wi * inp, 3, dims=1)
+    ghs = chunk(Wh, 2, dims=1)
+    bs = _mut_bias_chunks(b, 3)
+
+    forget_gate = sigmoid_fast.(gxs[1] .+ bs[1])
+    reset_gate = sigmoid_fast.(gxs[2] .+ ghs[1] * state .+ bs[2])
+    # The reset gate (not the update gate) scales the state fed to the candidate.
+    candidate_state = tanh_fast.(
+        ghs[2] * (reset_gate .* state) .+ tanh_fast.(gxs[3]) .+ bs[3]
+    )
+    new_state = candidate_state .* forget_gate .+ state .* (1 .- forget_gate)
+    return new_state
+end
+
+Base.show(io::IO, mut::MUT1Cell) =
+    print(io, "MUT1Cell(", size(mut.Wi, 2), " => ", size(mut.Wh, 2), ")")
\ No newline at end of file