From 65dbfe17153928467af1f3b3f57512dd6bac6044 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sun, 3 Nov 2024 16:16:06 +0100 Subject: [PATCH] some renaming --- src/indrnn_cell.jl | 14 ++++++++++---- src/lightru_cell.jl | 14 ++++++++++---- src/ligru_cell.jl | 12 ++++++++---- src/mgu_cell.jl | 12 ++++++++---- src/mut_cell.jl | 36 ++++++++++++++++++++++++------------ src/nas_cell.jl | 14 ++++++++++---- src/ran_cell.jl | 14 ++++++++++---- src/rhn_cell.jl | 10 +++++++--- src/scrn_cell.jl | 15 ++++++++++----- src/sru_cell.jl | 11 +++++++---- 10 files changed, 104 insertions(+), 48 deletions(-) diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index b62fd6c..8c2fe11 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -9,11 +9,17 @@ end Flux.@layer IndRNNCell """ - IndRNNCell((in, out)::Pair, σ=relu; init = glorot_uniform, bias = true) + IndRNNCell((in, out)::Pair, σ=relu; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ -function IndRNNCell((in, out)::Pair, σ=relu; init = glorot_uniform, bias = true) - Wi = init(out, in) - u = init(out) +function IndRNNCell((in, out)::Pair, σ=relu; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) + Wi = kernel_init(out, in) + u = recurrent_kernel_init(out) b = create_bias(Wi, bias, size(Wi, 1)) return IndRNNCell(σ, Wi, u, b) end diff --git a/src/lightru_cell.jl b/src/lightru_cell.jl index 8948575..b64e8df 100644 --- a/src/lightru_cell.jl +++ b/src/lightru_cell.jl @@ -8,11 +8,17 @@ end Flux.@layer LightRUCell """ - LightRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true) + LightRUCell((in, out)::Pair, σ=tanh; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ -function LightRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true) - Wi = init(2 * out, in) - Wh = init(out, out) +function LightRUCell((in, out)::Pair, σ=tanh; + kernel_init = glorot_uniform, + 
recurrent_kernel_init = glorot_uniform, + bias = true) + Wi = kernel_init(2 * out, in) + Wh = recurrent_kernel_init(out, out) b = create_bias(Wi, bias, size(Wh, 1)) return LightRUCell(Wi, Wh, b) diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index 35c2e90..4544075 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -8,14 +8,18 @@ end Flux.@layer LiGRUCell """ - LiGRUCell((in, out)::Pair; init = glorot_uniform, bias = true) + LiGRUCell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ function LiGRUCell((in, out)::Pair; - init = glorot_uniform, + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, bias = true) - Wi = init(out * 2, in) - Wh = init(out * 2, out) + Wi = kernel_init(out * 2, in) + Wh = recurrent_kernel_init(out * 2, out) b = create_bias(Wi, bias, size(Wi, 1)) return LiGRUCell(Wi, Wh, b) diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 32f7e64..85f33ae 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -8,14 +8,18 @@ end Flux.@layer MGUCell """ - MGUCell((in, out)::Pair; init = glorot_uniform, bias = true) + MGUCell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ function MGUCell((in, out)::Pair; - init = glorot_uniform, + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, bias = true) - Wi = init(out * 2, in) - Wh = init(out * 2, out) + Wi = kernel_init(out * 2, in) + Wh = recurrent_kernel_init(out * 2, out) b = create_bias(Wi, bias, size(Wi, 1)) return MGUCell(Wi, Wh, b) diff --git a/src/mut_cell.jl b/src/mut_cell.jl index 719cf60..5e7f529 100644 --- a/src/mut_cell.jl +++ b/src/mut_cell.jl @@ -8,14 +8,18 @@ end Flux.@layer MUT1Cell """ - MUT1Cell((in, out)::Pair; init = glorot_uniform, bias = true) + MUT1Cell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ function MUT1Cell((in, out)::Pair; - init = glorot_uniform, 
+ kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, bias = true) - Wi = init(out * 3, in) - Wh = init(out * 2, out) + Wi = kernel_init(out * 3, in) + Wh = recurrent_kernel_init(out * 2, out) b = create_bias(Wi, bias, 3 * out) return MUT1Cell(Wi, Wh, b) @@ -55,14 +59,18 @@ end Flux.@layer MUT2Cell """ - MUT2Cell((in, out)::Pair; init = glorot_uniform, bias = true) + MUT2Cell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ function MUT2Cell((in, out)::Pair; - init = glorot_uniform, + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, bias = true) - Wi = init(out * 3, in) - Wh = init(out * 3, out) + Wi = kernel_init(out * 3, in) + Wh = recurrent_kernel_init(out * 3, out) b = create_bias(Wi, bias, 3 * out) return MUT2Cell(Wi, Wh, b) @@ -102,14 +110,18 @@ end Flux.@layer MUT3Cell """ - MUT3Cell((in, out)::Pair; init = glorot_uniform, bias = true) + MUT3Cell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ function MUT3Cell((in, out)::Pair; - init = glorot_uniform, + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, bias = true) - Wi = init(out * 3, in) - Wh = init(out * 3, out) + Wi = kernel_init(out * 3, in) + Wh = recurrent_kernel_init(out * 3, out) b = create_bias(Wi, bias, 3 * out) return MUT3Cell(Wi, Wh, b) diff --git a/src/nas_cell.jl b/src/nas_cell.jl index a8d53c2..b9d8cd6 100644 --- a/src/nas_cell.jl +++ b/src/nas_cell.jl @@ -32,11 +32,17 @@ end Flux.@layer NASCell """ - NASCell((in, out)::Pair; init = glorot_uniform, bias = true) + NASCell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) """ -function NASCell((in, out)::Pair; init = glorot_uniform, bias = true) - Wi = init(8 * out, in) - Wh = init(8 * out, out) +function NASCell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + 
bias = true) + Wi = kernel_init(8 * out, in) + Wh = recurrent_kernel_init(8 * out, out) b = create_bias(Wi, bias, size(Wh, 1)) return NASCell(Wi, Wh, b) end diff --git a/src/ran_cell.jl b/src/ran_cell.jl index 6c0be1f..a69d813 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -9,7 +9,10 @@ Flux.@layer RANCell """ - RANCell(in => out; init = glorot_uniform, bias = true) + RANCell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) The `RANCell`, introduced in [this paper](https://arxiv.org/pdf/1705.07393), is a recurrent cell layer which provides additional memory through the @@ -51,9 +54,12 @@ result = rancell(inp) result_state = rancell(inp, (state, c_state)) ``` """ -function RANCell((in, out)::Pair; init = glorot_uniform, bias = true) - Wi = init(3 * out, in) - Wh = init(2 * out, out) +function RANCell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) + Wi = kernel_init(3 * out, in) + Wh = recurrent_kernel_init(2 * out, out) b = create_bias(Wi, bias, size(Wh, 1)) return RANCell(Wi, Wh, b) end diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl index 6702183..d4c3cd3 100644 --- a/src/rhn_cell.jl +++ b/src/rhn_cell.jl @@ -9,10 +9,14 @@ end Flux.@layer RHNCellUnit """ - RHNCellUnit((in, out)::Pair; init = glorot_uniform, bias = true) + RHNCellUnit((in, out)::Pair; + kernel_init = glorot_uniform, + bias = true) """ -function RHNCellUnit((in, out)::Pair; init = glorot_uniform, bias = true) - weight = init(3 * out, in) +function RHNCellUnit((in, out)::Pair; + kernel_init = glorot_uniform, + bias = true) + weight = kernel_init(3 * out, in) b = create_bias(weight, bias, size(weight, 1)) return RHNCellUnit(weight, b) end diff --git a/src/scrn_cell.jl b/src/scrn_cell.jl index bf35afc..ab06e8e 100644 --- a/src/scrn_cell.jl +++ b/src/scrn_cell.jl @@ -11,16 +11,21 @@ Flux.@layer SCRNCell """ - SCRNCell(in => out; init = glorot_uniform, bias = true)
+ SCRNCell((in, out)::Pair; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true, + alpha = 0.0) """ function SCRNCell((in, out)::Pair; - init = glorot_uniform, + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, bias = true, alpha = 0.0) - Wi = init(2 * out, in) - Wh = init(2 * out, out) - Wc = init(2 * out, out) + Wi = kernel_init(2 * out, in) + Wh = recurrent_kernel_init(2 * out, out) + Wc = recurrent_kernel_init(2 * out, out) b = create_bias(Wi, bias, size(Wh, 1)) return SCRNCell(Wi, Wh, Wc, b, alpha) end diff --git a/src/sru_cell.jl b/src/sru_cell.jl index ae56c20..57c65e6 100644 --- a/src/sru_cell.jl +++ b/src/sru_cell.jl @@ -8,10 +8,13 @@ end Flux.@layer SRUCell -function SRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true) - Wi = init(2 * out, in) - Wh = init(2 * out, out) - v = init(2 * out) +function SRUCell((in, out)::Pair, σ=tanh; + kernel_init = glorot_uniform, + recurrent_kernel_init = glorot_uniform, + bias = true) + Wi = kernel_init(2 * out, in) + Wh = recurrent_kernel_init(2 * out, out) + v = kernel_init(2 * out) b = create_bias(Wi, bias, size(Wh, 1)) return SRUCell(Wi, Wh, v, b)